Browse Source

Add api set_weights and get_weights

tags/v0.100.5-BERT-load
Wanglongzhi2001 2 years ago
parent
commit
78bd4c758e
3 changed files with 32 additions and 3 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
  2. +5
    -3
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  3. +24
    -0
      src/TensorFlowNET.Keras/Engine/Layer.cs

+ 3
- 0
src/TensorFlowNET.Core/Keras/Layers/ILayer.cs View File

@@ -1,5 +1,6 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
using Tensorflow.Training;

namespace Tensorflow.Keras
@@ -18,6 +19,8 @@ namespace Tensorflow.Keras
List<IVariableV1> TrainableWeights { get; }
List<IVariableV1> NonTrainableWeights { get; }
List<IVariableV1> Weights { get; set; }
void set_weights(List<NDArray> weights);
List<NDArray> get_weights();
Shape OutputShape { get; }
Shape BatchInputShape { get; }
TensorShapeConfig BuildInputShape { get; }


+ 5
- 3
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -21,6 +21,7 @@ using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
using Tensorflow.Operations;
using Tensorflow.Train;
using Tensorflow.Util;
@@ -71,7 +72,10 @@ namespace Tensorflow

public List<IVariableV1> TrainableVariables => throw new NotImplementedException();
public List<IVariableV1> TrainableWeights => throw new NotImplementedException();
public List<IVariableV1> Weights => throw new NotImplementedException();
public List<IVariableV1> Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }

public List<NDArray> get_weights() => throw new NotImplementedException();
public void set_weights(List<NDArray> weights) => throw new NotImplementedException();
public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException();

public Shape OutputShape => throw new NotImplementedException();
@@ -84,8 +88,6 @@ namespace Tensorflow
protected bool built = false;
public bool Built => built;

List<IVariableV1> ILayer.Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }

public RnnCell(bool trainable = true,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,


+ 24
- 0
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -120,6 +120,30 @@ namespace Tensorflow.Keras.Engine
}
}

public virtual void set_weights(List<NDArray> weights)
{
if (Weights.Count() != weights.Count()) throw new ValueError(
$"You called `set_weights` on layer \"{this.name}\"" +
$"with a weight list of length {len(weights)}, but the layer was " +
$"expecting {len(Weights)} weights.");
for (int i = 0; i < weights.Count(); i++)
{
if (weights[i].shape != Weights[i].shape)
{
throw new ValueError($"Layer weight shape {weights[i].shape} not compatible with provided weight shape {Weights[i].shape}");
}
}
foreach (var (this_w, v_w) in zip(Weights, weights))
this_w.assign(v_w, read_value: true);
}

public List<NDArray> get_weights()
{
List<NDArray > weights = new List<NDArray>();
weights.AddRange(Weights.ConvertAll(x => x.numpy()));
return weights;
}

protected int id;
public int Id => id;
protected string name;


Loading…
Cancel
Save