diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index 2b864f90..55409df3 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -17,7 +17,7 @@ namespace Tensorflow.Keras List TrainableVariables { get; } List TrainableWeights { get; } List NonTrainableWeights { get; } - List Weights { get; } + List Weights { get; set} Shape OutputShape { get; } Shape BatchInputShape { get; } TensorShapeConfig BuildInputShape { get; } diff --git a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs index 0aa5006c..cba621fa 100644 --- a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs +++ b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs @@ -77,7 +77,7 @@ public class EarlyStopping: ICallback // Restore the weights after first epoch if no progress is ever made. if (_restore_best_weights && _best_weights == null) { - _best_weights = _parameters.Model.TrainableWeights; + _best_weights = _parameters.Model.Weights; } _wait += 1; @@ -103,9 +103,7 @@ public class EarlyStopping: ICallback Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}"); } } - // Because loading the weight variable into the model has not yet been implemented, so Earlystopping can't load best_weight yet. - // TODO(Wanglongzhi2001): implement it. - // _parameters.Model.load_weights(best_weights); + _parameters.Model.Weights = _best_weights; } } public void on_train_end()