Browse Source

Merge pull request #1020 from Wanglongzhi2001/master

Finish EarlyStopping
tags/v0.100.5-BERT-load
Haiping GitHub 2 years ago
parent
commit
14da379b3c
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 5 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
  2. +2
    -4
      src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs

+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/ILayer.cs View File

@@ -17,7 +17,7 @@ namespace Tensorflow.Keras
List<IVariableV1> TrainableVariables { get; }
List<IVariableV1> TrainableWeights { get; }
List<IVariableV1> NonTrainableWeights { get; }
List<IVariableV1> Weights { get; }
List<IVariableV1> Weights { get; set}
Shape OutputShape { get; }
Shape BatchInputShape { get; }
TensorShapeConfig BuildInputShape { get; }


+ 2
- 4
src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs View File

@@ -77,7 +77,7 @@ public class EarlyStopping: ICallback
// Restore the weights after first epoch if no progress is ever made.
if (_restore_best_weights && _best_weights == null)
{
_best_weights = _parameters.Model.TrainableWeights;
_best_weights = _parameters.Model.Weights;
}
_wait += 1;

@@ -103,9 +103,7 @@ public class EarlyStopping: ICallback
Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}");
}
}
// Because loading the weight variable into the model has not yet been implemented, so Earlystopping can't load best_weight yet.
// TODO(Wanglongzhi2001): implement it.
// _parameters.Model.load_weights(best_weights);
_parameters.Model.Weights = _best_weights;
}
}
public void on_train_end()


Loading…
Cancel
Save