Browse Source

Merge pull request #1007 from Wanglongzhi2001/master

Add EarlyStopping callback
tags/v0.100.5-BERT-load
Haiping GitHub 2 years ago
parent
commit
d639ce3b7f
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 252 additions and 7 deletions
  1. +1
    -0
      src/TensorFlowNET.Core/Keras/Engine/ICallback.cs
  2. +4
    -0
      src/TensorFlowNET.Core/Keras/Engine/IModel.cs
  3. +3
    -2
      src/TensorFlowNET.Keras/Callbacks/CallbackList.cs
  4. +155
    -0
      src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs
  5. +1
    -0
      src/TensorFlowNET.Keras/Callbacks/History.cs
  6. +1
    -0
      src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs
  7. +16
    -5
      src/TensorFlowNET.Keras/Engine/Model.Fit.cs
  8. +6
    -0
      src/TensorFlowNET.Keras/Engine/Model.cs
  9. +65
    -0
      test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs

+ 1
- 0
src/TensorFlowNET.Core/Keras/Engine/ICallback.cs View File

@@ -4,6 +4,7 @@ public interface ICallback
{
Dictionary<string, List<float>> history { get; set; }
void on_train_begin();
void on_train_end();
void on_epoch_begin(int epoch);
void on_train_batch_begin(long step);
void on_train_batch_end(long end_step, Dictionary<string, float> logs);


+ 4
- 0
src/TensorFlowNET.Core/Keras/Engine/IModel.cs View File

@@ -17,6 +17,7 @@ public interface IModel : ILayer
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
bool shuffle = true,
int initial_epoch = 0,
@@ -28,6 +29,7 @@ public interface IModel : ILayer
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
bool shuffle = true,
int initial_epoch = 0,
@@ -73,4 +75,6 @@ public interface IModel : ILayer
void summary(int line_length = -1, float[] positions = null);

IKerasConfig get_config();

void set_stopTraining_true();
}

+ 3
- 2
src/TensorFlowNET.Keras/Callbacks/CallbackList.cs View File

@@ -7,7 +7,8 @@ namespace Tensorflow.Keras.Callbacks;

public class CallbackList
{
List<ICallback> callbacks = new List<ICallback>();
// Made public so that user-defined callbacks can be added to the callbacks list.
public List<ICallback> callbacks = new List<ICallback>();
public History History => callbacks[0] as History;

public CallbackList(CallbackParams parameters)
@@ -66,7 +67,7 @@ public class CallbackList

public void on_test_batch_begin(long step)
{
callbacks.ForEach(x => x.on_train_batch_begin(step));
callbacks.ForEach(x => x.on_test_batch_begin(step));
}
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
{


+ 155
- 0
src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs View File

@@ -0,0 +1,155 @@
using Tensorflow.Keras.Engine;
namespace Tensorflow.Keras.Callbacks;


/// <summary>
/// Stop training when a monitored metric has stopped improving.
/// Mirrors Keras' EarlyStopping: tracks a metric from the epoch logs and flips the
/// model's stop_training flag once it has not improved for <c>patience</c> epochs.
/// </summary>
public class EarlyStopping : ICallback
{
    int _patience;              // epochs with no improvement tolerated before stopping
    int _min_delta;             // minimum change in the monitored value to count as improvement
    int _verbose;
    int _stopped_epoch;         // epoch index at which training was stopped (0 if never)
    int _wait;                  // epochs elapsed since the last improvement
    int _best_epoch;
    int _start_from_epoch;      // monitor values before this epoch are ignored (warm-up)
    float _best;                // best monitored value seen so far
    float _baseline;            // baseline the monitor must also beat before _wait resets (0 = unset)
    string _monitor;            // key of the metric in the epoch logs, e.g. "val_loss"
    string _mode;               // "min", "max" or "auto" (direction inferred from the metric name)
    bool _restore_best_weights;
    List<IVariableV1>? _best_weights;
    CallbackParams _parameters;
    public Dictionary<string, List<float>>? history { get; set; }

    /// <summary>
    /// Creates the callback. <paramref name="parameters"/> must at least carry the model,
    /// because early stopping flips the model's stop_training flag and reads its weights.
    /// </summary>
    /// <param name="parameters">Callback parameters; the Model property is required.</param>
    /// <param name="monitor">Name of the metric to monitor in the epoch logs.</param>
    /// <param name="min_delta">Minimum change that qualifies as an improvement.</param>
    /// <param name="patience">Number of epochs with no improvement before stopping.</param>
    /// <param name="verbose">Verbosity: &gt; 0 prints stop/restore messages.</param>
    /// <param name="mode">"min", "max" or "auto"; unknown values fall back to "auto".</param>
    /// <param name="baseline">Baseline value the monitor must beat; 0 means no baseline.</param>
    /// <param name="restore_best_weights">Whether to snapshot the best weights (restoring is not implemented yet).</param>
    /// <param name="start_from_epoch">Epoch from which monitoring starts.</param>
    public EarlyStopping(CallbackParams parameters, string monitor = "val_loss", int min_delta = 0, int patience = 0,
        int verbose = 1, string mode = "auto", float baseline = 0f, bool restore_best_weights = false,
        int start_from_epoch = 0)
    {
        _parameters = parameters;
        _stopped_epoch = 0;
        _wait = 0;
        _monitor = monitor;
        _patience = patience;
        _verbose = verbose;
        _baseline = baseline;
        _start_from_epoch = start_from_epoch;
        _min_delta = Math.Abs(min_delta);
        _restore_best_weights = restore_best_weights;
        _mode = mode;
        if (mode != "auto" && mode != "min" && mode != "max")
        {
            // BUG FIX: the original used a printf-style "%s" placeholder, which
            // Console.WriteLine does not substitute — the literal "%s" was printed.
            Console.WriteLine($"EarlyStopping mode {mode} is unknown, fallback to auto mode.");
        }
    }

    public void on_train_begin()
    {
        _wait = 0;
        _stopped_epoch = 0;
        _best_epoch = 0;
        // BUG FIX: when the effective direction is "max" (e.g. accuracy-like metrics),
        // the initial best must be -Inf; with the original unconditional +Inf no value
        // could ever register as an improvement and training stopped immediately.
        _best = _is_maximizing() ? (float)-np.Inf : (float)np.Inf;
    }

    public void on_epoch_begin(int epoch)
    {

    }

    public void on_train_batch_begin(long step)
    {

    }

    public void on_train_batch_end(long end_step, Dictionary<string, float> logs)
    {
    }

    public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
    {
        var current = get_monitor_value(epoch_logs);
        // If no monitor value exists (0f sentinel) or still in the initial warm-up stage.
        if (current == 0f || epoch < _start_from_epoch)
            return;
        // Snapshot initial weights so there is something to restore even if no progress is ever made.
        if (_restore_best_weights && _best_weights == null)
        {
            _best_weights = _parameters.Model.TrainableWeights;
        }
        _wait += 1;

        if (_is_improvement(current, _best))
        {
            _best = current;
            _best_epoch = epoch;
            if (_restore_best_weights)
                _best_weights = _parameters.Model.TrainableWeights;
            // Only restart wait if we beat both the baseline and our previous best.
            if (_baseline == 0f || _is_improvement(current, _baseline))
                _wait = 0;
        }
        // Only check after the first epoch.
        if (_wait >= _patience && epoch > 0)
        {
            _stopped_epoch = epoch;
            _parameters.Model.set_stopTraining_true();
            if (_restore_best_weights && _best_weights != null)
            {
                if (_verbose > 0)
                {
                    Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}");
                }
            }
            // Because loading the weight variable into the model has not yet been implemented,
            // EarlyStopping can't load best_weights yet.
            // TODO(Wanglongzhi2001): implement it.
            // _parameters.Model.load_weights(best_weights);
        }
    }

    public void on_train_end()
    {
        if (_stopped_epoch > 0 && _verbose > 0)
        {
            Console.WriteLine($"Epoch {_stopped_epoch + 1}: early stopping");
        }
    }

    public void on_predict_begin() { }
    public void on_predict_batch_begin(long step) { }
    public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) { }
    public void on_predict_end() { }
    public void on_test_begin() { }
    public void on_test_batch_begin(long step) { }
    public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) { }

    /// <summary>
    /// Looks up the monitored metric in the epoch logs; warns and returns 0f when it is absent.
    /// </summary>
    float get_monitor_value(Dictionary<string, float> logs)
    {
        logs = logs ?? new Dictionary<string, float>();
        // BUG FIX: the original indexed logs[_monitor] directly, which throws
        // KeyNotFoundException when the metric is missing — the "not available"
        // warning below could never fire for a missing key (it fired instead
        // whenever the metric happened to equal 0).
        if (!logs.TryGetValue(_monitor, out var monitor_value))
        {
            Console.WriteLine($"Early stopping conditioned on metric {_monitor} " +
                $"which is not available. Available metrics are: {string.Join(", ", logs.Keys)}");
            return 0f;
        }
        return monitor_value;
    }

    /// <summary>
    /// True when <paramref name="monitor_value"/> improves on <paramref name="reference_value"/>
    /// by at least min_delta, in the direction implied by mode/monitor name.
    /// </summary>
    public bool _is_improvement(float monitor_value, float reference_value)
    {
        // BUG FIX: min_delta is applied with the correct sign per direction
        // (Keras semantics): max-mode requires value - delta > reference (strict),
        // min-mode requires value + delta < reference. The original subtracted the
        // delta in both directions and used a non-strict >= for max-mode.
        return _is_maximizing()
            ? (monitor_value - _min_delta) > reference_value
            : (monitor_value + _min_delta) < reference_value;
    }

    // True when larger monitored values are better, resolving "auto" mode via the
    // same metric-name heuristics as the original (acc / accuracy / auc suffixes).
    bool _is_maximizing()
    {
        if (_mode == "max")
            return true;
        if (_mode == "min")
            return false;
        return _monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc");
    }
}

+ 1
- 0
src/TensorFlowNET.Keras/Callbacks/History.cs View File

@@ -23,6 +23,7 @@ public class History : ICallback
epochs = new List<int>();
history = new Dictionary<string, List<float>>();
}
public void on_train_end() { }
public void on_epoch_begin(int epoch)
{



+ 1
- 0
src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs View File

@@ -22,6 +22,7 @@ namespace Tensorflow.Keras.Callbacks
_called_in_fit = true;
_sw = new Stopwatch();
}
public void on_train_end() { }
public void on_test_begin()
{
_sw = new Stopwatch();


+ 16
- 5
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -19,6 +19,7 @@ namespace Tensorflow.Keras.Engine
/// <param name="y"></param>
/// <param name="batch_size"></param>
/// <param name="epochs"></param>
/// <param name="callbacks"></param>
/// <param name="verbose"></param>
/// <param name="validation_split"></param>
/// <param name="shuffle"></param>
@@ -26,6 +27,7 @@ namespace Tensorflow.Keras.Engine
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
bool shuffle = true,
int initial_epoch = 0,
@@ -59,7 +61,7 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution
});

return FitInternal(data_handler, epochs, verbose, validation_data: null,
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null,
train_step_func: train_step_function);
}

@@ -67,6 +69,7 @@ namespace Tensorflow.Keras.Engine
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
bool shuffle = true,
int initial_epoch = 0,
@@ -107,12 +110,12 @@ namespace Tensorflow.Keras.Engine
if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
{
return FitInternal(data_handler, epochs, verbose, validation_data: null,
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null,
train_step_func: train_step_multi_inputs_function);
}
else
{
return FitInternal(data_handler, epochs, verbose, validation_data: null,
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null,
train_step_func: train_step_function);
}
}
@@ -122,6 +125,7 @@ namespace Tensorflow.Keras.Engine
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
bool shuffle = true,
int initial_epoch = 0,
@@ -143,11 +147,11 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution
});

return FitInternal(data_handler, epochs, verbose, validation_data: validation_data,
return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data,
train_step_func: train_step_function);
}

History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data,
History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data,
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
{
stop_training = false;
@@ -159,6 +163,13 @@ namespace Tensorflow.Keras.Engine
Epochs = epochs,
Steps = data_handler.Inferredsteps
});
if (callbackList != null)
{
foreach(var callback in callbackList)
callbacks.callbacks.add(callback);
}
callbacks.on_train_begin();

foreach (var (epoch, iterator) in data_handler.enumerate_epochs())


+ 6
- 0
src/TensorFlowNET.Keras/Engine/Model.cs View File

@@ -144,5 +144,11 @@ namespace Tensorflow.Keras.Engine
var children = base._trackable_children(save_type, cache);
return children;
}


void IModel.set_stopTraining_true()
{
stop_training = true;
}
}
}

+ 65
- 0
test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs View File

@@ -0,0 +1,65 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.Keras.UnitTest.Helpers;
using static Tensorflow.Binding;
using Tensorflow;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.Callbacks;
using Tensorflow.Keras.Engine;
using System.Collections.Generic;
using static Tensorflow.KerasApi;
using Tensorflow.Keras;


namespace TensorFlowNET.Keras.UnitTest
{
[TestClass]
public class EarltstoppingTest
{
    [TestMethod]
    // NOTE(review): loading weight variables back into the model is not implemented yet,
    // so avoid setting patience too large — the final weights will simply be those of the
    // last epoch rather than the best epoch.
    public void Earltstopping()
    {
        var layers = keras.layers;
        var model = keras.Sequential(new List<ILayer>
        {
        layers.Rescaling(1.0f / 255, input_shape: (32, 32, 3)),
        layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(128, activation: keras.activations.Relu),
        layers.Dense(10)
        });

        model.summary();

        model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
        loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true),
        metrics: new[] { "acc" });

        var num_epochs = 3;
        var batch_size = 8;

        var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
        x_train = x_train / 255.0f;
        // Define a CallbackParams first; the parameters you pass must at least contain Model and Epochs.
        CallbackParams callback_parameters = new CallbackParams
        {
        Model = model,
        Epochs = num_epochs,
        };
        // Define the early-stopping callback, monitoring the "accuracy" metric.
        ICallback earlystop = new EarlyStopping(callback_parameters, "accuracy");
        // Define a callback list, then add the early-stopping callback to it.
        var callbacks = new List<ICallback>();
        callbacks.add(earlystop);

        model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs,callbacks:callbacks);
    }

}


}


Loading…
Cancel
Save