From 46729714d388ce1ae35c7ff32fcd02a5b20f58ac Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Sun, 9 Apr 2023 16:15:46 +0800 Subject: [PATCH] Finish EarlyStopping --- src/TensorFlowNET.Core/Keras/Layers/ILayer.cs | 2 +- src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index 2b864f90..55409df3 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -17,7 +17,7 @@ namespace Tensorflow.Keras List TrainableVariables { get; } List TrainableWeights { get; } List NonTrainableWeights { get; } - List Weights { get; } + List Weights { get; set} Shape OutputShape { get; } Shape BatchInputShape { get; } TensorShapeConfig BuildInputShape { get; } diff --git a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs index 0aa5006c..cba621fa 100644 --- a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs +++ b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs @@ -77,7 +77,7 @@ public class EarlyStopping: ICallback // Restore the weights after first epoch if no progress is ever made. if (_restore_best_weights && _best_weights == null) { - _best_weights = _parameters.Model.TrainableWeights; + _best_weights = _parameters.Model.Weights; } _wait += 1; @@ -103,9 +103,7 @@ public class EarlyStopping: ICallback Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}"); } } - // Because loading the weight variable into the model has not yet been implemented, so Earlystopping can't load best_weight yet. - // TODO(Wanglongzhi2001): implement it. - // _parameters.Model.load_weights(best_weights); + _parameters.Model.Weights = _best_weights; } } public void on_train_end()