Browse Source

Merge pull request #1021 from Wanglongzhi2001/master

Finish EarlyStopping
tags/v0.100.5-BERT-load
Haiping GitHub 2 years ago
parent
commit
ad36e37f75
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 4 additions and 2 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
  2. +2
    -0
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  3. +1
    -1
      src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs

+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/ILayer.cs View File

@@ -17,7 +17,7 @@ namespace Tensorflow.Keras
List<IVariableV1> TrainableVariables { get; }
List<IVariableV1> TrainableWeights { get; }
List<IVariableV1> NonTrainableWeights { get; }
List<IVariableV1> Weights { get; set}
List<IVariableV1> Weights { get; set; }
Shape OutputShape { get; }
Shape BatchInputShape { get; }
TensorShapeConfig BuildInputShape { get; }


+ 2
- 0
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -84,6 +84,8 @@ namespace Tensorflow
protected bool built = false;
public bool Built => built;

List<IVariableV1> ILayer.Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }

public RnnCell(bool trainable = true,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,


+ 1
- 1
src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs View File

@@ -102,8 +102,8 @@ public class EarlyStopping: ICallback
{
Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}");
}
_parameters.Model.Weights = _best_weights;
}
_parameters.Model.Weights = _best_weights;
}
}
public void on_train_end()


Loading…
Cancel
Save