Browse Source

fix: fix EarlyStopping

tags/v0.110.4-Transformer-Model
Wanglongzhi2001 2 years ago
parent
commit
f809f6eace
1 changed file with 42 additions and 22 deletions
  1. +42
    -22
      src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs

+ 42
- 22
src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs View File

@@ -19,8 +19,10 @@ public class EarlyStopping: ICallback
string _monitor; string _monitor;
string _mode; string _mode;
bool _restore_best_weights; bool _restore_best_weights;
List<IVariableV1>? _best_weights;
List<NDArray>? _best_weights;
CallbackParams _parameters; CallbackParams _parameters;
Func<NDArray, NDArray, NDArray> _monitor_op;

public Dictionary<string, List<float>>? history { get; set; } public Dictionary<string, List<float>>? history { get; set; }
// The user needs to pass a CallbackParams to EarlyStopping; the CallbackParams must at least contain the model. // The user needs to pass a CallbackParams to EarlyStopping; the CallbackParams must at least contain the model.
public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", float min_delta = 0f, int patience = 0, public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", float min_delta = 0f, int patience = 0,
@@ -38,17 +40,49 @@ public class EarlyStopping: ICallback
_min_delta = Math.Abs(min_delta); _min_delta = Math.Abs(min_delta);
_restore_best_weights = restore_best_weights; _restore_best_weights = restore_best_weights;
_mode = mode; _mode = mode;
if (mode != "auto" && mode != "min" && mode != "max")

if (_mode != "auto" && _mode != "min" && _mode != "max")
{
Console.WriteLine($"EarlyStopping mode {_mode} is unknown, fallback to auto mode.");
_mode = "auto";
}

if (_mode == "min")
{
_monitor_op = np.less;
}
else if (_mode == "max")
{
_monitor_op = np.greater;
}
else
{
if (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc"))
{
_monitor_op = np.greater;
}
else
{
_monitor_op = np.less;
}
}

if (_monitor_op == np.greater)
{ {
Console.WriteLine("EarlyStopping mode %s is unknown, fallback to auto mode.", mode);
_min_delta *= 1;
}
else
{
_min_delta *= -1;
} }
} }
public void on_train_begin() public void on_train_begin()
{ {
_wait = 0; _wait = 0;
_stopped_epoch = 0; _stopped_epoch = 0;
_best = _monitor_op == np.less ? (float)np.Inf : (float)-np.Inf;
_best_weights = null;
_best_epoch = 0; _best_epoch = 0;
_best = (float)np.Inf;
} }


public void on_epoch_begin(int epoch) public void on_epoch_begin(int epoch)
@@ -74,7 +108,7 @@ public class EarlyStopping: ICallback
// Restore the weights after first epoch if no progress is ever made. // Restore the weights after first epoch if no progress is ever made.
if (_restore_best_weights && _best_weights == null) if (_restore_best_weights && _best_weights == null)
{ {
_best_weights = _parameters.Model.Weights;
_best_weights = _parameters.Model.get_weights();
} }
_wait += 1; _wait += 1;


@@ -83,7 +117,7 @@ public class EarlyStopping: ICallback
_best = current; _best = current;
_best_epoch = epoch; _best_epoch = epoch;
if (_restore_best_weights) if (_restore_best_weights)
_best_weights = _parameters.Model.TrainableWeights;
_best_weights = _parameters.Model.get_weights();
// Only restart wait if we beat both the baseline and our previous best. // Only restart wait if we beat both the baseline and our previous best.
if (_baseline == 0f || _is_improvement(current, _baseline)) if (_baseline == 0f || _is_improvement(current, _baseline))
_wait = 0; _wait = 0;
@@ -99,7 +133,7 @@ public class EarlyStopping: ICallback
{ {
Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}"); Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}");
} }
_parameters.Model.Weights = _best_weights;
_parameters.Model.set_weights(_best_weights);
} }
} }
} }
@@ -131,21 +165,7 @@ public class EarlyStopping: ICallback
} }
public bool _is_improvement(float monitor_value, float reference_value) public bool _is_improvement(float monitor_value, float reference_value)
{ {
bool less_op = (monitor_value - _min_delta) < reference_value;
bool greater_op = (monitor_value - _min_delta) >= reference_value;
if (_mode == "min")
return less_op;
else if (_mode == "max")
return greater_op;
else
{
if (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc"))
{
return greater_op;
}
else
return less_op;
}
return _monitor_op(monitor_value - _min_delta, reference_value);
} }


public void on_test_end(Dictionary<string, float> logs) public void on_test_end(Dictionary<string, float> logs)


Loading…
Cancel
Save