|
@@ -19,8 +19,10 @@ public class EarlyStopping: ICallback |
|
|
string _monitor; |
|
|
string _monitor; |
|
|
string _mode; |
|
|
string _mode; |
|
|
bool _restore_best_weights; |
|
|
bool _restore_best_weights; |
|
|
List<IVariableV1>? _best_weights; |
|
|
|
|
|
|
|
|
List<NDArray>? _best_weights; |
|
|
CallbackParams _parameters; |
|
|
CallbackParams _parameters; |
|
|
|
|
|
Func<NDArray, NDArray, NDArray> _monitor_op; |
|
|
|
|
|
|
|
|
public Dictionary<string, List<float>>? history { get; set; } |
|
|
public Dictionary<string, List<float>>? history { get; set; } |
|
|
// user need to pass a CallbackParams to EarlyStopping, CallbackParams at least need the model |
|
|
// user need to pass a CallbackParams to EarlyStopping, CallbackParams at least need the model |
|
|
public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", float min_delta = 0f, int patience = 0, |
|
|
public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", float min_delta = 0f, int patience = 0, |
|
@@ -38,17 +40,49 @@ public class EarlyStopping: ICallback |
|
|
_min_delta = Math.Abs(min_delta); |
|
|
_min_delta = Math.Abs(min_delta); |
|
|
_restore_best_weights = restore_best_weights; |
|
|
_restore_best_weights = restore_best_weights; |
|
|
_mode = mode; |
|
|
_mode = mode; |
|
|
if (mode != "auto" && mode != "min" && mode != "max") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (_mode != "auto" && _mode != "min" && _mode != "max") |
|
|
|
|
|
{ |
|
|
|
|
|
Console.WriteLine($"EarlyStopping mode {_mode} is unknown, fallback to auto mode."); |
|
|
|
|
|
_mode = "auto"; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (_mode == "min") |
|
|
|
|
|
{ |
|
|
|
|
|
_monitor_op = np.less; |
|
|
|
|
|
} |
|
|
|
|
|
else if (_mode == "max") |
|
|
|
|
|
{ |
|
|
|
|
|
_monitor_op = np.greater; |
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
if (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc")) |
|
|
|
|
|
{ |
|
|
|
|
|
_monitor_op = np.greater; |
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
_monitor_op = np.less; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (_monitor_op == np.greater) |
|
|
{ |
|
|
{ |
|
|
Console.WriteLine("EarlyStopping mode %s is unknown, fallback to auto mode.", mode); |
|
|
|
|
|
|
|
|
_min_delta *= 1; |
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
_min_delta *= -1; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
public void on_train_begin() |
|
|
public void on_train_begin() |
|
|
{ |
|
|
{ |
|
|
_wait = 0; |
|
|
_wait = 0; |
|
|
_stopped_epoch = 0; |
|
|
_stopped_epoch = 0; |
|
|
|
|
|
_best = _monitor_op == np.less ? (float)np.Inf : (float)-np.Inf; |
|
|
|
|
|
_best_weights = null; |
|
|
_best_epoch = 0; |
|
|
_best_epoch = 0; |
|
|
_best = (float)np.Inf; |
|
|
|
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
public void on_epoch_begin(int epoch) |
|
|
public void on_epoch_begin(int epoch) |
|
@@ -74,7 +108,7 @@ public class EarlyStopping: ICallback |
|
|
// Restore the weights after first epoch if no progress is ever made. |
|
|
// Restore the weights after first epoch if no progress is ever made. |
|
|
if (_restore_best_weights && _best_weights == null) |
|
|
if (_restore_best_weights && _best_weights == null) |
|
|
{ |
|
|
{ |
|
|
_best_weights = _parameters.Model.Weights; |
|
|
|
|
|
|
|
|
_best_weights = _parameters.Model.get_weights(); |
|
|
} |
|
|
} |
|
|
_wait += 1; |
|
|
_wait += 1; |
|
|
|
|
|
|
|
@@ -83,7 +117,7 @@ public class EarlyStopping: ICallback |
|
|
_best = current; |
|
|
_best = current; |
|
|
_best_epoch = epoch; |
|
|
_best_epoch = epoch; |
|
|
if (_restore_best_weights) |
|
|
if (_restore_best_weights) |
|
|
_best_weights = _parameters.Model.TrainableWeights; |
|
|
|
|
|
|
|
|
_best_weights = _parameters.Model.get_weights(); |
|
|
// Only restart wait if we beat both the baseline and our previous best. |
|
|
// Only restart wait if we beat both the baseline and our previous best. |
|
|
if (_baseline == 0f || _is_improvement(current, _baseline)) |
|
|
if (_baseline == 0f || _is_improvement(current, _baseline)) |
|
|
_wait = 0; |
|
|
_wait = 0; |
|
@@ -99,7 +133,7 @@ public class EarlyStopping: ICallback |
|
|
{ |
|
|
{ |
|
|
Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}"); |
|
|
Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}"); |
|
|
} |
|
|
} |
|
|
_parameters.Model.Weights = _best_weights; |
|
|
|
|
|
|
|
|
_parameters.Model.set_weights(_best_weights); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
@@ -131,21 +165,7 @@ public class EarlyStopping: ICallback |
|
|
} |
|
|
} |
|
|
public bool _is_improvement(float monitor_value, float reference_value) |
|
|
public bool _is_improvement(float monitor_value, float reference_value) |
|
|
{ |
|
|
{ |
|
|
bool less_op = (monitor_value - _min_delta) < reference_value; |
|
|
|
|
|
bool greater_op = (monitor_value - _min_delta) >= reference_value; |
|
|
|
|
|
if (_mode == "min") |
|
|
|
|
|
return less_op; |
|
|
|
|
|
else if (_mode == "max") |
|
|
|
|
|
return greater_op; |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
if (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc")) |
|
|
|
|
|
{ |
|
|
|
|
|
return greater_op; |
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
return less_op; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
return _monitor_op(monitor_value - _min_delta, reference_value); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
public void on_test_end(Dictionary<string, float> logs) |
|
|
public void on_test_end(Dictionary<string, float> logs) |
|
|