
[bugfix] dist_trainer's save & load

tags/v0.4.10
yunfan authored 5 years ago
commit 05eb499eb8
4 changed files with 22 additions and 20 deletions:
  1. fastNLP/core/__init__.py (+1, -2)
  2. fastNLP/core/callback.py (+1, -1)
  3. fastNLP/core/dist_trainer.py (+17, -16)
  4. test/core/test_dist_trainer.py (+3, -1)

fastNLP/core/__init__.py (+1, -2)

@@ -49,7 +49,6 @@ __all__ = [
"WarmupCallback",
'SaveModelCallback',
"EchoCallback",
"TesterCallback",
"CallbackException",
"EarlyStopError",
@@ -79,7 +78,7 @@ from ._logger import logger
from .batch import DataSetIter, BatchIter, TorchLoaderIter
from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \
-    TesterCallback, CallbackException, EarlyStopError
+    CallbackException, EarlyStopError
from .const import Const
from .dataset import DataSet
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
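The net effect of the export changes above is that the tester callback is no longer part of the public fastNLP.core namespace; code that still needs it has to import the private name from the callback module. A minimal sketch of the import that should keep working after this commit (shown for illustration only, since relying on a private class is not part of the public API):

```python
# TesterCallback is gone from fastNLP.core.__all__; the class now lives
# under its private name and must be imported from the callback module.
from fastNLP.core.callback import _TesterCallback
```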


fastNLP/core/callback.py (+1, -1)

@@ -63,7 +63,7 @@ __all__ = [
"WarmupCallback",
"SaveModelCallback",
"EchoCallback",
"TesterCallback",
"_TesterCallback",
"CallbackException",
"EarlyStopError"


fastNLP/core/dist_trainer.py (+17, -16)

@@ -17,7 +17,8 @@ from tqdm import tqdm

from ._logger import logger
from .batch import DataSetIter, BatchIter
- from .callback import DistCallbackManager, CallbackException, _TesterCallback
+ from .callback import DistCallbackManager, CallbackException
+ from .callback import _TesterCallback
from .dataset import DataSet
from .losses import _prepare_losser
from .optimizer import Optimizer
@@ -174,13 +175,13 @@ class DistTrainer():
cb = _TesterCallback(
dev_data, model, metrics,
batch_size=batch_size_per_gpu, num_workers=num_workers)
- self.test_manager.add_callback([cb], master=True)
+ self.test_manager.add_callback([cb], master=False)

# Setup logging
dist.barrier()
self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
if self.save_path:
- self.cp_save_path = os.path.join(self.save_path, 'checkpoints', self.start_time)
+ self.cp_save_path = os.path.join(self.save_path, 'checkpoints')
else:
self.cp_save_path = None

@@ -286,11 +287,11 @@ class DistTrainer():
results['seconds'] = round(time.time() - start_time, 2)
self.logger.info("###### Train finished ######")
self.logger.info('Total train time: {} seconds.'. format(results['seconds']))
- if load_best_model:
-     self.load_check_point('best_{}'.format(self.metric_key))
+ if load_best_model and self.cp_save_path and len(self.test_manager.callbacks):
+     self.load_check_point('best')
finally:
- pass
+ dist.barrier()
return results

def _train(self):
@@ -417,29 +418,29 @@ class DistTrainer():
def load_check_point(self, name):
path = os.path.join(self.cp_save_path, name)
self.logger.info('reload best model from %s', path)
- model_load = torch.load(path)
+ model_load = torch.load(path, map_location='cpu')
if not isinstance(model_load, dict):
model_load = model_load.state_dict()
- self.model.load_state_dict(model_load)
+ self.model.module.load_state_dict(model_load)

def _do_validation(self):
self.callback_manager.on_valid_begin()
+ # do evaluate on all nodes
eval_res = self.test_manager.on_valid_begin()
eval_res = list(filter(lambda x: x is not None, eval_res))
if len(eval_res):
eval_res, is_better = list(zip(*eval_res))
else:
eval_res, is_better = None, None
+ # save better model on master node
+ if self.is_master and is_better is not None and self.cp_save_path:
+     for i, better_flag in enumerate(is_better):
+         if better_flag:
+             # TODO to support multiple datasets to evaluate
+             self.save_check_point('best')
+             break
self.callback_manager.on_valid_end(
eval_res, self.metric_key, self.optimizer, is_better)

- # save better model
- for i, better_flag in enumerate(is_better):
-     if better_flag:
-         # TODO to support multiple datasets to evaluate
-         name = 'best_{}'.format(self.metric_key)
-         self.save_check_point(name)
-         break
dist.barrier()

def close(self):
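For context, the load path above follows the usual pattern for restoring a checkpoint into a DistributedDataParallel-wrapped model: read the file on CPU so every rank can load it regardless of device placement, normalize a full model object into a state dict, and push the weights into the wrapped module. A standalone sketch of that pattern (function name and arguments are illustrative, not fastNLP API):

```python
import torch

def load_into_ddp(ddp_model, checkpoint_path):
    """Load a checkpoint into a DistributedDataParallel-wrapped model."""
    # map_location='cpu' lets every rank read the checkpoint without assuming
    # it is loaded on the same GPU it was saved from.
    state = torch.load(checkpoint_path, map_location='cpu')
    if not isinstance(state, dict):
        # the checkpoint may hold a whole model object rather than a state dict
        state = state.state_dict()
    # DDP wraps the real model; its parameters live under .module
    ddp_model.module.load_state_dict(state)
```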


test/core/test_dist_trainer.py (+3, -1)

@@ -130,12 +130,14 @@ class TestDistTrainer(unittest.TestCase):
train_set, model, optimizer=SGD(lr=0.1),
loss=BCELoss(pred="predict", target="y"),
batch_size_per_gpu=32, n_epochs=3, print_every=50, dev_data=dev_set,
- metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
+ metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=self.save_path,
)
trainer.train()
"""
# should run correctly
"""
+ if trainer.is_master and os.path.exists(self.save_path):
+     shutil.rmtree(self.save_path)

def run_dist(self, run_id):
if torch.cuda.is_available():
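The added teardown keeps checkpoint cleanup on the master process only, so the ranks do not race to delete the same directory. A standalone sketch of that guard (helper name is made up for illustration):

```python
import os
import shutil

def cleanup_checkpoints(is_master, save_path):
    # only one process (the master) removes the checkpoint directory;
    # the other ranks skip the deletion entirely
    if is_master and os.path.exists(save_path):
        shutil.rmtree(save_path)
```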

