|
|
|
@@ -0,0 +1,109 @@ |
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates. |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from modelscope.metainfo import Hooks |
|
|
|
from modelscope.utils.logger import get_logger |
|
|
|
from .builder import HOOKS |
|
|
|
from .hook import Hook |
|
|
|
from .priority import Priority |
|
|
|
|
|
|
|
|
|
|
|
@HOOKS.register_module(module_name=Hooks.EarlyStopHook)
class EarlyStopHook(Hook):
    """Early stop when a specific metric stops improving.

    Args:
        metric_key (str): Metric key to be monitored.
        rule (str): Comparison rule for best score. Support "max" and "min".
            If rule is "max", the training will stop when `metric_key` has stopped increasing.
            If rule is "min", the training will stop when `metric_key` has stopped decreasing.
        patience (int): Trainer will stop if the monitored metric did not improve for the last `patience` times.
        min_delta (float): Minimum change in the monitored metric to qualify as an improvement.
        check_finite (bool): If true, stops training when the metric becomes NaN or infinite.
        by_epoch (bool): Whether to trigger the early stop check by epoch or by iteration.
        interval (int): The frequency to trigger early stop check. If `by_epoch=True`,
            it means the number of epochs, else means the number of iterations.
    """

    # Run after every other hook so the metrics they produce are available.
    PRIORITY = Priority.VERY_LOW

    # Comparison deciding whether a new score x beats the current best y.
    rule_map = {'max': lambda x, y: x > y, 'min': lambda x, y: x < y}

    def __init__(self,
                 metric_key: str,
                 rule: str = 'max',
                 patience: int = 3,
                 min_delta: float = 0.0,
                 check_finite: bool = True,
                 by_epoch: bool = True,
                 interval: int = 1):
        # Fail fast on an invalid rule instead of a late KeyError in
        # _should_stop during training.
        if rule not in self.rule_map:
            raise ValueError(
                f'rule must be one of {list(self.rule_map)}, but got: {rule}')
        self.metric_key = metric_key
        self.rule = rule
        self.patience = patience
        self.min_delta = min_delta
        self.check_finite = check_finite
        self.by_epoch = by_epoch
        self.interval = interval

        # Number of consecutive checks without improvement.
        self.wait_count = 0
        # Initialize so that any finite first score counts as an improvement.
        self.best_score = float('inf') if rule == 'min' else -float('inf')

    def before_run(self, trainer):
        """Bind a logger: prefer the trainer's, fall back to a module logger."""
        if not hasattr(trainer, 'logger'):
            self.logger = get_logger(__name__)
        else:
            self.logger = trainer.logger

    def _should_stop(self, trainer):
        """Return True when training should stop.

        Stops when the monitored metric is non-finite (if ``check_finite``)
        or has not improved by at least ``min_delta`` for ``patience``
        consecutive checks.
        """
        metric_values = trainer.metric_values

        # Metrics may not be available yet (e.g. before the first evaluation).
        if metric_values is None:
            return False

        if self.metric_key not in metric_values:
            raise ValueError(
                f'Metric not found: {self.metric_key} not in {metric_values}')

        should_stop = False
        current_score = metric_values[self.metric_key]
        # Handicap the new score by min_delta in the "worse" direction so an
        # improvement only counts when it exceeds min_delta. (Subtracting
        # min_delta unconditionally would, for rule='min', accept a score
        # that got *worse* by less than min_delta as an improvement.)
        delta = -self.min_delta if self.rule == 'max' else self.min_delta
        if self.check_finite and not np.isfinite(current_score):
            should_stop = True
            self.logger.warning(
                f'Metric {self.metric_key} = {current_score} is not finite. '
                f'Previous best metric: {self.best_score:.4f}.')
        elif self.rule_map[self.rule](current_score + delta, self.best_score):
            self.best_score = current_score
            self.wait_count = 0
        else:
            self.wait_count += 1
            if self.wait_count >= self.patience:
                should_stop = True
                self.logger.info(
                    f'Metric {self.metric_key} did not improve in the last {self.wait_count} epochs or iterations. '
                    f'Best score: {self.best_score:.4f}.')
        return should_stop

    def _stop_training(self, trainer):
        """Signal the trainer to stop by setting its stop flag."""
        self.logger.info('Early Stopping!')
        trainer._stop_training = True

    def after_train_epoch(self, trainer):
        """Check early stop at the end of an epoch (epoch mode only)."""
        if not self.by_epoch:
            return

        if not self.every_n_epochs(trainer, self.interval):
            return

        if self._should_stop(trainer):
            self._stop_training(trainer)

    def after_train_iter(self, trainer):
        """Check early stop at the end of an iteration (iteration mode only)."""
        if self.by_epoch:
            return

        if not self.every_n_iters(trainer, self.interval):
            return

        if self._should_stop(trainer):
            self._stop_training(trainer)