Browse Source

[to #42322933] Add early stop hook

master^2
pangda 3 years ago
parent
commit
5fd3e7bb43
4 changed files with 114 additions and 0 deletions
  1. +1
    -0
      modelscope/metainfo.py
  2. +1
    -0
      modelscope/trainers/hooks/__init__.py
  3. +109
    -0
      modelscope/trainers/hooks/early_stop_hook.py
  4. +3
    -0
      modelscope/trainers/trainer.py

+ 1
- 0
modelscope/metainfo.py View File

@@ -522,6 +522,7 @@ class Hooks(object):
ClipClampLogitScaleHook = 'ClipClampLogitScaleHook'

# train
EarlyStopHook = 'EarlyStopHook'
DeepspeedHook = 'DeepspeedHook'




+ 1
- 0
modelscope/trainers/hooks/__init__.py View File

@@ -6,6 +6,7 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .builder import HOOKS, build_hook
from .checkpoint_hook import BestCkptSaverHook, CheckpointHook
from .early_stop_hook import EarlyStopHook
from .compression import SparsityHook
from .evaluation_hook import EvaluationHook
from .hook import Hook


+ 109
- 0
modelscope/trainers/hooks/early_stop_hook.py View File

@@ -0,0 +1,109 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import numpy as np

from modelscope.metainfo import Hooks
from modelscope.utils.logger import get_logger
from .builder import HOOKS
from .hook import Hook
from .priority import Priority


@HOOKS.register_module(module_name=Hooks.EarlyStopHook)
class EarlyStopHook(Hook):
"""Early stop when a specific metric stops improving.

Args:
metric_key (str): Metric key to be monitored.
rule (str): Comparison rule for best score. Support "max" and "min".
If rule is "max", the training will stop when `metric_key` has stopped increaing.
If rule is "min", the training will stop when `metric_key` has stopped decreasing.
patience (int): Trainer will stop if the monitored metric did not improve for the last `patience` times.
min_delta (float): Minimum change in the monitored metric to quailfy as an improvement.
check_finite (bool): If true, stops training when the metric becomes NaN or infinite.
by_epoch (int): Saving checkpoints by epoch or by iteration.
interval (int): The frequency to trigger early stop check. If `by_epoch=True`,
it means the number of epochs, else means the number of iterations.
"""

PRIORITY = Priority.VERY_LOW
rule_map = {'max': lambda x, y: x > y, 'min': lambda x, y: x < y}

def __init__(self,
metric_key: str,
rule: str = 'max',
patience: int = 3,
min_delta: float = 0.0,
check_finite: bool = True,
by_epoch: bool = True,
interval: int = 1):
self.metric_key = metric_key
self.rule = rule
self.patience = patience
self.min_delta = min_delta
self.check_finite = check_finite
self.by_epoch = by_epoch
self.interval = interval

self.wait_count = 0
self.best_score = float('inf') if rule == 'min' else -float('inf')

def before_run(self, trainer):
if not hasattr(trainer, 'logger'):
self.logger = get_logger(__name__)
else:
self.logger = trainer.logger

def _should_stop(self, trainer):
metric_values = trainer.metric_values

if metric_values is None:
return False

if self.metric_key not in metric_values:
raise ValueError(
f'Metric not found: {self.metric_key} not in {metric_values}')

should_stop = False
current_score = metric_values[self.metric_key]
if self.check_finite and not np.isfinite(current_score):
should_stop = True
self.logger.warn(
f'Metric {self.metric_key} = {current_score} is not finite. '
f'Previous best metric: {self.best_score:.4f}.')
elif self.rule_map[self.rule](current_score - self.min_delta,
self.best_score):
self.best_score = current_score
self.wait_count = 0
else:
self.wait_count += 1
if self.wait_count >= self.patience:
should_stop = True
self.logger.info(
f'Metric {self.metric_key} did not improve in the last {self.wait_count} epochs or iterations. '
f'Best score: {self.best_score:.4f}.')
return should_stop

def _stop_training(self, trainer):
self.logger.info('Early Stopping!')
trainer._stop_training = True

def after_train_epoch(self, trainer):
if not self.by_epoch:
return

if not self.every_n_epochs(trainer, self.interval):
return

if self._should_stop(trainer):
self._stop_training(trainer)

def after_train_iter(self, trainer):
if self.by_epoch:
return

if not self.every_n_iters(trainer, self.interval):
return

if self._should_stop(trainer):
self._stop_training(trainer)

+ 3
- 0
modelscope/trainers/trainer.py View File

@@ -112,6 +112,7 @@ class EpochBasedTrainer(BaseTrainer):
self._epoch = 0
self._iter = 0
self._inner_iter = 0
self._stop_training = False

if isinstance(model, str):
self.model_dir = self.get_or_download_model_dir(
@@ -910,6 +911,8 @@ class EpochBasedTrainer(BaseTrainer):
# Value changed after the hooks are invoked, do not move them above the invoke_hook code.
self._inner_iter = 0
self._epoch += 1
if self._stop_training:
break

self.invoke_hook(TrainerStages.after_run)



Loading…
Cancel
Save