@@ -41,7 +41,7 @@ from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
                                           init_dist, set_random_seed)
 from .base import BaseTrainer
 from .builder import TRAINERS
-from .default_config import DEFAULT_CONFIG
+from .default_config import merge_cfg
 from .hooks.hook import Hook
 from .parallel.builder import build_parallel
 from .parallel.utils import is_parallel
@@ -114,7 +114,7 @@ class EpochBasedTrainer(BaseTrainer):
         super().__init__(cfg_file, arg_parse_fn)

         # add default config
-        self.cfg.merge_from_dict(self._get_default_config(), force=False)
+        merge_cfg(self.cfg)
         self.cfg = self.rebuild_config(self.cfg)

         if 'cfg_options' in kwargs:
@@ -951,9 +951,6 @@ class EpochBasedTrainer(BaseTrainer):
             stage_hook_infos.append(info)
         return '\n'.join(stage_hook_infos)

-    def _get_default_config(self):
-        return DEFAULT_CONFIG
-

 def worker_init_fn(worker_id, num_workers, rank, seed):
     # The seed of each worker equals to
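Reviewer note: the change swaps the per-instance `_get_default_config()` indirection, which simply returned the module constant `DEFAULT_CONFIG`, for a module-level `merge_cfg` helper called right after `super().__init__`. A minimal, behavior-preserving sketch of that helper, assuming it does nothing beyond folding `DEFAULT_CONFIG` into the config with `force=False` exactly as the removed call did (the shipped implementation in `default_config.py` may well do more):

```python
# default_config.py (sketch; placeholder contents, illustrative only)
DEFAULT_CONFIG = {'train': {'hooks': []}}  # assumed shape of the defaults


def merge_cfg(cfg):
    """Fold the library defaults into ``cfg`` in place.

    ``force=False`` keeps any key the user config already defines,
    matching the removed ``merge_from_dict(..., force=False)`` call.
    """
    cfg.merge_from_dict(DEFAULT_CONFIG, force=False)
```

One side effect worth flagging in review: `_get_default_config` was an overridable method, so subclasses that customized the defaults now have to route those changes through `rebuild_config` or the new helper instead.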
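The final context line cuts off inside the `worker_init_fn` comment, but the signature points at the standard per-worker seeding pattern: give every DataLoader worker on every rank a distinct, reproducible seed. A sketch of that pattern for reference; the signature matches the diff, while the body is an assumption rather than the file's actual code:

```python
import random

import numpy as np


def worker_init_fn(worker_id, num_workers, rank, seed):
    # Assumed body: derive a unique, reproducible seed per dataloader
    # worker so random augmentations differ across workers and ranks
    # but stay repeatable across runs.
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
```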