|
|
@@ -215,6 +215,10 @@ class CheckpointHook(Hook): |
|
|
# TODO a temp fix to avoid pipeline_name and task mismatch |
|
|
# TODO a temp fix to avoid pipeline_name and task mismatch |
|
|
config['pipeline'] = {'type': config['task']} |
|
|
config['pipeline'] = {'type': config['task']} |
|
|
|
|
|
|
|
|
|
|
|
# remove parallel module that is not JSON serializable |
|
|
|
|
|
if 'parallel' in config and 'module' in config['parallel']: |
|
|
|
|
|
del config['parallel']['module'] |
|
|
|
|
|
|
|
|
class SaveConfig: |
|
|
class SaveConfig: |
|
|
|
|
|
|
|
|
def __init__(self, output_dir, config): |
|
|
def __init__(self, output_dir, config): |
|
|
@@ -422,4 +426,5 @@ class BestCkptSaverHook(CheckpointHook): |
|
|
|
|
|
|
|
|
def after_run(self, trainer): |
|
|
def after_run(self, trainer): |
|
|
if self.restore_best: |
|
|
if self.restore_best: |
|
|
self.load_checkpoint(self._best_ckpt_file, trainer) |
|
|
|
|
|
|
|
|
if is_master(): |
|
|
|
|
|
self.load_checkpoint(self._best_ckpt_file, trainer) |