Add EarlyStopping callback
@@ -4,6 +4,7 @@ public interface ICallback | |||
{ | |||
Dictionary<string, List<float>> history { get; set; } | |||
void on_train_begin(); | |||
void on_train_end(); | |||
void on_epoch_begin(int epoch); | |||
void on_train_batch_begin(long step); | |||
void on_train_batch_end(long end_step, Dictionary<string, float> logs); | |||
@@ -17,6 +17,7 @@ public interface IModel : ILayer | |||
int batch_size = -1, | |||
int epochs = 1, | |||
int verbose = 1, | |||
List<ICallback> callbacks = null, | |||
float validation_split = 0f, | |||
bool shuffle = true, | |||
int initial_epoch = 0, | |||
@@ -28,6 +29,7 @@ public interface IModel : ILayer | |||
int batch_size = -1, | |||
int epochs = 1, | |||
int verbose = 1, | |||
List<ICallback> callbacks = null, | |||
float validation_split = 0f, | |||
bool shuffle = true, | |||
int initial_epoch = 0, | |||
@@ -73,4 +75,6 @@ public interface IModel : ILayer | |||
void summary(int line_length = -1, float[] positions = null); | |||
IKerasConfig get_config(); | |||
void set_stopTraining_true(); | |||
} |
@@ -7,7 +7,8 @@ namespace Tensorflow.Keras.Callbacks; | |||
public class CallbackList | |||
{ | |||
List<ICallback> callbacks = new List<ICallback>(); | |||
// Made public so that user-defined callbacks can be added to the callback list.
public List<ICallback> callbacks = new List<ICallback>(); | |||
public History History => callbacks[0] as History; | |||
public CallbackList(CallbackParams parameters) | |||
@@ -66,7 +67,7 @@ public class CallbackList | |||
public void on_test_batch_begin(long step) | |||
{ | |||
callbacks.ForEach(x => x.on_train_batch_begin(step)); | |||
callbacks.ForEach(x => x.on_test_batch_begin(step)); | |||
} | |||
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) | |||
{ | |||
@@ -0,0 +1,155 @@ | |||
using Tensorflow.Keras.Engine; | |||
namespace Tensorflow.Keras.Callbacks; | |||
/// <summary> | |||
/// Stop training when a monitored metric has stopped improving. | |||
/// </summary> | |||
/// <param name="parameters"></param> | |||
/// <param name="monitor"></param> | |||
/// <summary>
/// Stop training when a monitored metric has stopped improving.
/// The caller must supply a <c>CallbackParams</c> that at least carries the model,
/// since early stopping asks the model to stop via <c>set_stopTraining_true()</c>.
/// </summary>
public class EarlyStopping : ICallback
{
    // Number of epochs with no improvement after which training is stopped.
    int _patience;
    // Minimum change in the monitored value to qualify as an improvement.
    int _min_delta;
    int _verbose;
    // Epoch at which training was stopped (0 means it was never stopped early).
    int _stopped_epoch;
    // Epochs elapsed since the last improvement.
    int _wait;
    int _best_epoch;
    // Epochs to wait before early stopping starts monitoring at all.
    int _start_from_epoch;
    float _best;
    float _baseline;
    string _monitor;
    string _mode;
    bool _restore_best_weights;
    List<IVariableV1>? _best_weights;
    CallbackParams _parameters;
    public Dictionary<string, List<float>>? history { get; set; }

    /// <summary>
    /// Creates the callback.
    /// </summary>
    /// <param name="parameters">Must at least carry the model being trained.</param>
    /// <param name="monitor">Name of the metric to watch, e.g. "val_loss" or "accuracy".</param>
    /// <param name="min_delta">Minimum absolute change counted as an improvement.</param>
    /// <param name="patience">Epochs with no improvement before stopping.</param>
    /// <param name="verbose">When &gt; 0, progress messages are written to the console.</param>
    /// <param name="mode">"min", "max", or "auto" (inferred from the metric name).</param>
    /// <param name="baseline">Baseline the metric must beat before the wait counter resets (0 = none).</param>
    /// <param name="restore_best_weights">Snapshot the best weights seen so far (restoring is not implemented yet).</param>
    /// <param name="start_from_epoch">First epoch at which monitoring begins.</param>
    public EarlyStopping(CallbackParams parameters, string monitor = "val_loss", int min_delta = 0, int patience = 0,
        int verbose = 1, string mode = "auto", float baseline = 0f, bool restore_best_weights = false,
        int start_from_epoch = 0)
    {
        _parameters = parameters;
        _stopped_epoch = 0;
        _wait = 0;
        _monitor = monitor;
        _patience = patience;
        _verbose = verbose;
        _baseline = baseline;
        _start_from_epoch = start_from_epoch;
        _min_delta = Math.Abs(min_delta);
        _restore_best_weights = restore_best_weights;
        if (mode != "auto" && mode != "min" && mode != "max")
        {
            // BUG FIX: the original passed a C-style "%s" placeholder, which
            // Console.WriteLine does not substitute, and it never actually fell
            // back to auto mode despite the message saying so.
            Console.WriteLine($"EarlyStopping mode {mode} is unknown, fallback to auto mode.");
            mode = "auto";
        }
        _mode = mode;
    }

    public void on_train_begin()
    {
        _wait = 0;
        _stopped_epoch = 0;
        _best_epoch = 0;
        // BUG FIX: the original always started from +Inf; in "max" mode that made
        // improvement impossible, so training always stopped after `patience` epochs.
        // Start from the worst possible value for the chosen direction instead.
        _best = _is_max_mode() ? float.NegativeInfinity : float.PositiveInfinity;
    }

    public void on_epoch_begin(int epoch)
    {
    }

    public void on_train_batch_begin(long step)
    {
    }

    public void on_train_batch_end(long end_step, Dictionary<string, float> logs)
    {
    }

    public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
    {
        var current = get_monitor_value(epoch_logs);
        // 0 is used as the "metric missing" sentinel (a limitation inherited from the
        // original port); also skip epochs inside the warm-up window.
        if (current == 0f || epoch < _start_from_epoch)
            return;
        // Take an initial weights snapshot so something can be restored even if
        // no epoch ever improves.
        if (_restore_best_weights && _best_weights == null)
        {
            _best_weights = _parameters.Model.TrainableWeights;
        }
        _wait += 1;
        if (_is_improvement(current, _best))
        {
            _best = current;
            _best_epoch = epoch;
            if (_restore_best_weights)
                _best_weights = _parameters.Model.TrainableWeights;
            // Only restart the wait counter if we beat both the baseline and our previous best.
            if (_baseline == 0f || _is_improvement(current, _baseline))
                _wait = 0;
        }
        // Only check after the first epoch.
        if (_wait >= _patience && epoch > 0)
        {
            _stopped_epoch = epoch;
            _parameters.Model.set_stopTraining_true();
            if (_restore_best_weights && _best_weights != null)
            {
                if (_verbose > 0)
                {
                    Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}");
                }
            }
            // Loading weight variables back into the model is not implemented yet,
            // so EarlyStopping cannot actually restore best_weights.
            // TODO(Wanglongzhi2001): implement it.
            // _parameters.Model.load_weights(best_weights);
        }
    }

    public void on_train_end()
    {
        if (_stopped_epoch > 0 && _verbose > 0)
        {
            Console.WriteLine($"Epoch {_stopped_epoch + 1}: early stopping");
        }
    }

    public void on_predict_begin() { }
    public void on_predict_batch_begin(long step) { }
    public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) { }
    public void on_predict_end() { }
    public void on_test_begin() { }
    public void on_test_batch_begin(long step) { }
    public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) { }

    /// <summary>
    /// Reads the monitored metric from the epoch logs; returns 0 (the "missing"
    /// sentinel) with a console warning when the metric is absent.
    /// </summary>
    float get_monitor_value(Dictionary<string, float> logs)
    {
        logs = logs ?? new Dictionary<string, float>();
        // BUG FIX: the original indexed logs[_monitor] directly, which throws
        // KeyNotFoundException for a missing metric instead of reaching the warning.
        if (!logs.TryGetValue(_monitor, out float monitor_value) || monitor_value == 0f)
        {
            Console.WriteLine($"Early stopping conditioned on metric {_monitor} " +
                $"which is not available. Available metrics are: {string.Join(", ", logs.Keys)}");
        }
        return monitor_value;
    }

    /// <summary>
    /// True when <paramref name="monitor_value"/> improves on <paramref name="reference_value"/>
    /// by at least min_delta, in the direction implied by the mode / metric name.
    /// </summary>
    public bool _is_improvement(float monitor_value, float reference_value)
    {
        // BUG FIX: the original used >= in max mode, counting "no change" as an
        // improvement; Keras semantics are a strict comparison after subtracting min_delta.
        return _is_max_mode()
            ? monitor_value - _min_delta > reference_value
            : monitor_value - _min_delta < reference_value;
    }

    // True when a larger monitored value is better: explicit "max" mode, or auto
    // mode with an accuracy-like metric name (acc / accuracy / auc suffix).
    bool _is_max_mode()
        => _mode == "max"
           || (_mode != "min" && (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc")));
}
@@ -23,6 +23,7 @@ public class History : ICallback | |||
epochs = new List<int>(); | |||
history = new Dictionary<string, List<float>>(); | |||
} | |||
public void on_train_end() { } | |||
public void on_epoch_begin(int epoch) | |||
{ | |||
@@ -22,6 +22,7 @@ namespace Tensorflow.Keras.Callbacks | |||
_called_in_fit = true; | |||
_sw = new Stopwatch(); | |||
} | |||
public void on_train_end() { } | |||
public void on_test_begin() | |||
{ | |||
_sw = new Stopwatch(); | |||
@@ -19,6 +19,7 @@ namespace Tensorflow.Keras.Engine | |||
/// <param name="y"></param> | |||
/// <param name="batch_size"></param> | |||
/// <param name="epochs"></param> | |||
/// <param name="callbacks"></param> | |||
/// <param name="verbose"></param> | |||
/// <param name="validation_split"></param> | |||
/// <param name="shuffle"></param> | |||
@@ -26,6 +27,7 @@ namespace Tensorflow.Keras.Engine | |||
int batch_size = -1, | |||
int epochs = 1, | |||
int verbose = 1, | |||
List<ICallback> callbacks = null, | |||
float validation_split = 0f, | |||
bool shuffle = true, | |||
int initial_epoch = 0, | |||
@@ -59,7 +61,7 @@ namespace Tensorflow.Keras.Engine | |||
StepsPerExecution = _steps_per_execution | |||
}); | |||
return FitInternal(data_handler, epochs, verbose, validation_data: null, | |||
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, | |||
train_step_func: train_step_function); | |||
} | |||
@@ -67,6 +69,7 @@ namespace Tensorflow.Keras.Engine | |||
int batch_size = -1, | |||
int epochs = 1, | |||
int verbose = 1, | |||
List<ICallback> callbacks = null, | |||
float validation_split = 0f, | |||
bool shuffle = true, | |||
int initial_epoch = 0, | |||
@@ -107,12 +110,12 @@ namespace Tensorflow.Keras.Engine | |||
if (data_handler.DataAdapter.GetDataset().structure.Length > 2 || | |||
data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1) | |||
{ | |||
return FitInternal(data_handler, epochs, verbose, validation_data: null, | |||
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, | |||
train_step_func: train_step_multi_inputs_function); | |||
} | |||
else | |||
{ | |||
return FitInternal(data_handler, epochs, verbose, validation_data: null, | |||
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, | |||
train_step_func: train_step_function); | |||
} | |||
} | |||
@@ -122,6 +125,7 @@ namespace Tensorflow.Keras.Engine | |||
int batch_size = -1, | |||
int epochs = 1, | |||
int verbose = 1, | |||
List<ICallback> callbacks = null, | |||
float validation_split = 0f, | |||
bool shuffle = true, | |||
int initial_epoch = 0, | |||
@@ -143,11 +147,11 @@ namespace Tensorflow.Keras.Engine | |||
StepsPerExecution = _steps_per_execution | |||
}); | |||
return FitInternal(data_handler, epochs, verbose, validation_data: validation_data, | |||
return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data, | |||
train_step_func: train_step_function); | |||
} | |||
History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data, | |||
History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data, | |||
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func) | |||
{ | |||
stop_training = false; | |||
@@ -159,6 +163,13 @@ namespace Tensorflow.Keras.Engine | |||
Epochs = epochs, | |||
Steps = data_handler.Inferredsteps | |||
}); | |||
if (callbackList != null) | |||
{ | |||
foreach(var callback in callbackList) | |||
callbacks.callbacks.add(callback); | |||
} | |||
callbacks.on_train_begin(); | |||
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||
@@ -144,5 +144,11 @@ namespace Tensorflow.Keras.Engine | |||
var children = base._trackable_children(save_type, cache); | |||
return children; | |||
} | |||
void IModel.set_stopTraining_true() | |||
{ | |||
stop_training = true; | |||
} | |||
} | |||
} |
@@ -0,0 +1,65 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Tensorflow.Keras.UnitTest.Helpers; | |||
using static Tensorflow.Binding; | |||
using Tensorflow; | |||
using Tensorflow.Keras.Optimizers; | |||
using Tensorflow.Keras.Callbacks; | |||
using Tensorflow.Keras.Engine; | |||
using System.Collections.Generic; | |||
using static Tensorflow.KerasApi; | |||
using Tensorflow.Keras; | |||
namespace TensorFlowNET.Keras.UnitTest
{
    [TestClass]
    public class EarltstoppingTest
    {
        [TestMethod]
        // Loading weight variables back into the model is not implemented yet, so it is
        // better not to set patience too large: the model keeps the last epoch's weights,
        // not those of the best epoch.
        public void Earltstopping()
        {
            // Small CNN on CIFAR-10, just enough to exercise the EarlyStopping callback.
            var layers = keras.layers;
            var model = keras.Sequential(new List<ILayer>
            {
                layers.Rescaling(1.0f / 255, input_shape: (32, 32, 3)),
                layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu),
                layers.MaxPooling2D(),
                layers.Flatten(),
                layers.Dense(128, activation: keras.activations.Relu),
                layers.Dense(10)
            });

            model.summary();
            model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
                loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true),
                metrics: new[] { "acc" });

            var num_epochs = 3;
            var batch_size = 8;

            var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
            x_train = x_train / 255.0f;
            // Define a CallbackParams first; the parameters passed must at least
            // contain Model and Epochs.
            CallbackParams callback_parameters = new CallbackParams
            {
                Model = model,
                Epochs = num_epochs,
            };
            // Define the early-stopping callback.
            ICallback earlystop = new EarlyStopping(callback_parameters, "accuracy");
            // Define a callback list, then add the early stopping to it.
            var callbacks = new List<ICallback>();
            callbacks.add(earlystop);

            model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs, callbacks: callbacks);
        }
    }
}