From 5fd3e7bb43a37048ffae9ad5229934924d6625cd Mon Sep 17 00:00:00 2001
From: pangda
Date: Tue, 6 Dec 2022 10:54:47 +0800
Subject: [PATCH] [to #42322933] Add early stop hook

---
 modelscope/metainfo.py                       |   1 +
 modelscope/trainers/hooks/__init__.py        |   1 +
 modelscope/trainers/hooks/early_stop_hook.py | 109 +++++++++++++++++++
 modelscope/trainers/trainer.py               |   3 +
 4 files changed, 114 insertions(+)
 create mode 100644 modelscope/trainers/hooks/early_stop_hook.py

diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 663069df..f9c9f2fb 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -522,6 +522,7 @@ class Hooks(object):
     ClipClampLogitScaleHook = 'ClipClampLogitScaleHook'
 
     # train
+    EarlyStopHook = 'EarlyStopHook'
     DeepspeedHook = 'DeepspeedHook'
 
 
diff --git a/modelscope/trainers/hooks/__init__.py b/modelscope/trainers/hooks/__init__.py
index c7bd93aa..11a73f24 100644
--- a/modelscope/trainers/hooks/__init__.py
+++ b/modelscope/trainers/hooks/__init__.py
@@ -6,6 +6,7 @@ from modelscope.utils.import_utils import LazyImportModule
 if TYPE_CHECKING:
     from .builder import HOOKS, build_hook
     from .checkpoint_hook import BestCkptSaverHook, CheckpointHook
     from .compression import SparsityHook
+    from .early_stop_hook import EarlyStopHook
     from .evaluation_hook import EvaluationHook
     from .hook import Hook
diff --git a/modelscope/trainers/hooks/early_stop_hook.py b/modelscope/trainers/hooks/early_stop_hook.py
new file mode 100644
index 00000000..765d94f8
--- /dev/null
+++ b/modelscope/trainers/hooks/early_stop_hook.py
@@ -0,0 +1,109 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import numpy as np
+
+from modelscope.metainfo import Hooks
+from modelscope.utils.logger import get_logger
+from .builder import HOOKS
+from .hook import Hook
+from .priority import Priority
+
+
+@HOOKS.register_module(module_name=Hooks.EarlyStopHook)
+class EarlyStopHook(Hook):
+    """Early stop when a specific metric stops improving.
+
+    Args:
+        metric_key (str): Metric key to be monitored.
+        rule (str): Comparison rule for best score. Support "max" and "min".
+            If rule is "max", the training will stop when `metric_key` has stopped increasing.
+            If rule is "min", the training will stop when `metric_key` has stopped decreasing.
+        patience (int): Trainer will stop if the monitored metric did not improve for the last `patience` times.
+        min_delta (float): Minimum change in the monitored metric to qualify as an improvement.
+        check_finite (bool): If true, stops training when the metric becomes NaN or infinite.
+        by_epoch (bool): Whether to check the early-stop condition by epoch or by iteration.
+        interval (int): The frequency to trigger early stop check. If `by_epoch=True`,
+            it means the number of epochs, else means the number of iterations.
+    """
+
+    PRIORITY = Priority.VERY_LOW
+    rule_map = {'max': lambda x, y: x > y, 'min': lambda x, y: x < y}
+
+    def __init__(self,
+                 metric_key: str,
+                 rule: str = 'max',
+                 patience: int = 3,
+                 min_delta: float = 0.0,
+                 check_finite: bool = True,
+                 by_epoch: bool = True,
+                 interval: int = 1):
+        self.metric_key = metric_key
+        self.rule = rule
+        self.patience = patience
+        self.min_delta = min_delta
+        self.check_finite = check_finite
+        self.by_epoch = by_epoch
+        self.interval = interval
+
+        self.wait_count = 0
+        self.best_score = float('inf') if rule == 'min' else -float('inf')
+
+    def before_run(self, trainer):
+        if not hasattr(trainer, 'logger'):
+            self.logger = get_logger(__name__)
+        else:
+            self.logger = trainer.logger
+
+    def _should_stop(self, trainer):
+        metric_values = trainer.metric_values
+
+        if metric_values is None:
+            return False
+
+        if self.metric_key not in metric_values:
+            raise ValueError(
+                f'Metric not found: {self.metric_key} not in {metric_values}')
+
+        should_stop = False
+        current_score = metric_values[self.metric_key]
+        if self.check_finite and not np.isfinite(current_score):
+            should_stop = True
+            self.logger.warning(
+                f'Metric {self.metric_key} = {current_score} is not finite. '
+                f'Previous best metric: {self.best_score:.4f}.')
+        elif self.rule_map[self.rule](current_score - self.min_delta,
+                                      self.best_score):
+            self.best_score = current_score
+            self.wait_count = 0
+        else:
+            self.wait_count += 1
+            if self.wait_count >= self.patience:
+                should_stop = True
+                self.logger.info(
+                    f'Metric {self.metric_key} did not improve in the last {self.wait_count} epochs or iterations. '
+                    f'Best score: {self.best_score:.4f}.')
+        return should_stop
+
+    def _stop_training(self, trainer):
+        self.logger.info('Early Stopping!')
+        trainer._stop_training = True
+
+    def after_train_epoch(self, trainer):
+        if not self.by_epoch:
+            return
+
+        if not self.every_n_epochs(trainer, self.interval):
+            return
+
+        if self._should_stop(trainer):
+            self._stop_training(trainer)
+
+    def after_train_iter(self, trainer):
+        if self.by_epoch:
+            return
+
+        if not self.every_n_iters(trainer, self.interval):
+            return
+
+        if self._should_stop(trainer):
+            self._stop_training(trainer)
diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py
index e70ad2b4..df2dc25f 100644
--- a/modelscope/trainers/trainer.py
+++ b/modelscope/trainers/trainer.py
@@ -112,6 +112,7 @@ class EpochBasedTrainer(BaseTrainer):
         self._epoch = 0
         self._iter = 0
         self._inner_iter = 0
+        self._stop_training = False
 
         if isinstance(model, str):
             self.model_dir = self.get_or_download_model_dir(
@@ -910,6 +911,8 @@ class EpochBasedTrainer(BaseTrainer):
             # Value changed after the hooks are invoked, do not move them above the invoke_hook code.
             self._inner_iter = 0
             self._epoch += 1
+            if self._stop_training:
+                break
 
         self.invoke_hook(TrainerStages.after_run)