Browse Source

[to #42322933] fix bug: checkpoint hook and bestckpthook exist at the same time

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10227608
master
yuze.zyz yingda.chen 3 years ago
parent
commit
357a233ee3
3 changed files with 21 additions and 8 deletions
  1. +19
    -0
      modelscope/trainers/default_config.py
  2. +2
    -5
      modelscope/trainers/trainer.py
  3. +0
    -3
      tests/trainers/hooks/test_checkpoint_hook.py

+ 19
- 0
modelscope/trainers/default_config.py View File

@@ -1,4 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from modelscope.utils.config import Config

DEFAULT_CONFIG = {
'train': {
'hooks': [{
@@ -12,3 +15,19 @@ DEFAULT_CONFIG = {
}]
}
}


def merge_cfg(cfg: Config):
    """Merge DEFAULT_CONFIG into `cfg` and drop the redundant checkpoint hook.

    The defaults are merged through ``cfg.merge_from_dict(..., force=False)``.
    Afterwards, if any entry of ``cfg.train.hooks`` has the type
    'BestCkptSaverHook', every entry whose type is 'CheckpointHook' is
    removed from that list.

    @param cfg: The config object to be updated in place.
    """
    cfg.merge_from_dict(DEFAULT_CONFIG, force=False)

    # Collect the type of every configured hook up front.
    hook_types = [hook['type'] for hook in cfg.train.hooks]
    if 'BestCkptSaverHook' not in hook_types:
        return

    # A BestCkptSaverHook is present: keep everything except CheckpointHook.
    cfg.train.hooks = [
        hook for hook in cfg.train.hooks if hook['type'] != 'CheckpointHook'
    ]

+ 2
- 5
modelscope/trainers/trainer.py View File

@@ -41,7 +41,7 @@ from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
init_dist, set_random_seed)
from .base import BaseTrainer
from .builder import TRAINERS
from .default_config import DEFAULT_CONFIG
from .default_config import merge_cfg
from .hooks.hook import Hook
from .parallel.builder import build_parallel
from .parallel.utils import is_parallel
@@ -114,7 +114,7 @@ class EpochBasedTrainer(BaseTrainer):
super().__init__(cfg_file, arg_parse_fn)

# add default config
self.cfg.merge_from_dict(self._get_default_config(), force=False)
merge_cfg(self.cfg)
self.cfg = self.rebuild_config(self.cfg)

if 'cfg_options' in kwargs:
@@ -951,9 +951,6 @@ class EpochBasedTrainer(BaseTrainer):
stage_hook_infos.append(info)
return '\n'.join(stage_hook_infos)

def _get_default_config(self):
return DEFAULT_CONFIG


def worker_init_fn(worker_id, num_workers, rank, seed):
# The seed of each worker equals to


+ 0
- 3
tests/trainers/hooks/test_checkpoint_hook.py View File

@@ -204,9 +204,6 @@ class BestCkptSaverHookTest(unittest.TestCase):
trainer = build_trainer(trainer_name, kwargs)
trainer.train()
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)
self.assertIn(f'best_{LogKeys.EPOCH}1_{MetricKeys.ACCURACY}0.1.pth',
results_files)



Loading…
Cancel
Save