Browse Source

Merge pull request #1180 from Wanglongzhi2001/master

fix: fix EarlyStopping
tags/v0.110.4-Transformer-Model
Haiping GitHub 2 years ago
parent
commit
3811e4e140
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 22 deletions
  1. +6
    -0
      src/TensorFlowNET.Core/NumPy/Numpy.Math.cs
  2. +42
    -22
      src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs

+ 6
- 0
src/TensorFlowNET.Core/NumPy/Numpy.Math.cs View File

@@ -85,5 +85,11 @@ namespace Tensorflow.NumPy

[AutoNumPy]
public static NDArray add(NDArray x, NDArray y) => new NDArray(math_ops.add(x, y));

[AutoNumPy]
public static NDArray greater(NDArray x, NDArray y) => new NDArray(tf.greater(x, y));

[AutoNumPy]
public static NDArray less(NDArray x, NDArray y) => new NDArray(tf.less(x, y));
}
}

+ 42
- 22
src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs View File

@@ -19,8 +19,10 @@ public class EarlyStopping: ICallback
string _monitor;
string _mode;
bool _restore_best_weights;
List<IVariableV1>? _best_weights;
List<NDArray>? _best_weights;
CallbackParams _parameters;
Func<NDArray, NDArray, NDArray> _monitor_op;

public Dictionary<string, List<float>>? history { get; set; }
// user need to pass a CallbackParams to EarlyStopping, CallbackParams at least need the model
public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", float min_delta = 0f, int patience = 0,
@@ -38,17 +40,49 @@ public class EarlyStopping: ICallback
_min_delta = Math.Abs(min_delta);
_restore_best_weights = restore_best_weights;
_mode = mode;
if (mode != "auto" && mode != "min" && mode != "max")

if (_mode != "auto" && _mode != "min" && _mode != "max")
{
Console.WriteLine($"EarlyStopping mode {_mode} is unknown, fallback to auto mode.");
_mode = "auto";
}

if (_mode == "min")
{
_monitor_op = np.less;
}
else if (_mode == "max")
{
_monitor_op = np.greater;
}
else
{
if (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc"))
{
_monitor_op = np.greater;
}
else
{
_monitor_op = np.less;
}
}

if (_monitor_op == np.greater)
{
Console.WriteLine("EarlyStopping mode %s is unknown, fallback to auto mode.", mode);
_min_delta *= 1;
}
else
{
_min_delta *= -1;
}
}
public void on_train_begin()
{
_wait = 0;
_stopped_epoch = 0;
_best = _monitor_op == np.less ? (float)np.Inf : (float)-np.Inf;
_best_weights = null;
_best_epoch = 0;
_best = (float)np.Inf;
}

public void on_epoch_begin(int epoch)
@@ -74,7 +108,7 @@ public class EarlyStopping: ICallback
// Restore the weights after first epoch if no progress is ever made.
if (_restore_best_weights && _best_weights == null)
{
_best_weights = _parameters.Model.Weights;
_best_weights = _parameters.Model.get_weights();
}
_wait += 1;

@@ -83,7 +117,7 @@ public class EarlyStopping: ICallback
_best = current;
_best_epoch = epoch;
if (_restore_best_weights)
_best_weights = _parameters.Model.TrainableWeights;
_best_weights = _parameters.Model.get_weights();
// Only restart wait if we beat both the baseline and our previous best.
if (_baseline == 0f || _is_improvement(current, _baseline))
_wait = 0;
@@ -99,7 +133,7 @@ public class EarlyStopping: ICallback
{
Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}");
}
_parameters.Model.Weights = _best_weights;
_parameters.Model.set_weights(_best_weights);
}
}
}
@@ -131,21 +165,7 @@ public class EarlyStopping: ICallback
}
public bool _is_improvement(float monitor_value, float reference_value)
{
bool less_op = (monitor_value - _min_delta) < reference_value;
bool greater_op = (monitor_value - _min_delta) >= reference_value;
if (_mode == "min")
return less_op;
else if (_mode == "max")
return greater_op;
else
{
if (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc"))
{
return greater_op;
}
else
return less_op;
}
return _monitor_op(monitor_value - _min_delta, reference_value);
}

public void on_test_end(Dictionary<string, float> logs)


Loading…
Cancel
Save