Browse Source

Merge pull request #1030 from Wanglongzhi2001/master

Add set_weights and get_weights APIs
tags/v0.100.5-BERT-load
Haiping GitHub 2 years ago
parent
commit
e72024b520
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 67 additions and 3 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
  2. +5
    -3
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  3. +59
    -0
      src/TensorFlowNET.Keras/Engine/Layer.cs

+ 3
- 0
src/TensorFlowNET.Core/Keras/Layers/ILayer.cs View File

@@ -1,5 +1,6 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
using Tensorflow.Training;

namespace Tensorflow.Keras
@@ -18,6 +19,8 @@ namespace Tensorflow.Keras
List<IVariableV1> TrainableWeights { get; }
List<IVariableV1> NonTrainableWeights { get; }
List<IVariableV1> Weights { get; set; }
/// <summary>
/// Sets the layer's weights from the supplied sequence of arrays.
/// NOTE(review): the arrays presumably correspond one-to-one, in order, with
/// <see cref="Weights"/> — confirm against the concrete implementations.
/// </summary>
/// <param name="weights">The arrays to load into the layer.</param>
void set_weights(IEnumerable<NDArray> weights);
/// <summary>
/// Returns the layer's weights as a list of <see cref="NDArray"/> values.
/// </summary>
List<NDArray> get_weights();
Shape OutputShape { get; }
Shape BatchInputShape { get; }
TensorShapeConfig BuildInputShape { get; }


+ 5
- 3
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -21,6 +21,7 @@ using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
using Tensorflow.Operations;
using Tensorflow.Train;
using Tensorflow.Util;
@@ -71,7 +72,10 @@ namespace Tensorflow

public List<IVariableV1> TrainableVariables => throw new NotImplementedException();
public List<IVariableV1> TrainableWeights => throw new NotImplementedException();
public List<IVariableV1> Weights => throw new NotImplementedException();
public List<IVariableV1> Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }

// Not implemented for this cell type: calling it always throws NotImplementedException.
public List<NDArray> get_weights() => throw new NotImplementedException();
// Not implemented for this cell type: calling it always throws NotImplementedException.
public void set_weights(IEnumerable<NDArray> weights) => throw new NotImplementedException();
public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException();

public Shape OutputShape => throw new NotImplementedException();
@@ -84,8 +88,6 @@ namespace Tensorflow
protected bool built = false;
public bool Built => built;

List<IVariableV1> ILayer.Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }

public RnnCell(bool trainable = true,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,


+ 59
- 0
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -30,6 +30,9 @@ using Tensorflow.Training;
using Tensorflow.Training.Saving.SavedModel;
using Tensorflow.Util;
using static Tensorflow.Binding;
using Tensorflow.Framework;
using Tensorflow.Sessions;


namespace Tensorflow.Keras.Engine
{
@@ -134,6 +137,62 @@ namespace Tensorflow.Keras.Engine
}
}

/// <summary>
/// Assigns new values to this layer's weights from a sequence of arrays.
/// </summary>
/// <remarks>
/// The arrays are matched positionally against <see cref="Weights"/>. All
/// counts and shapes are validated before any variable is modified, so a
/// failed validation leaves the layer untouched. Assignment is only
/// performed when executing eagerly; in graph mode the call currently does
/// nothing (see the TODO in the body).
/// </remarks>
/// <param name="weights">
/// One array per entry of <see cref="Weights"/>, in the same order.
/// </param>
/// <exception cref="System.ArgumentNullException">
/// <paramref name="weights"/> is null.
/// </exception>
/// <exception cref="ValueError">
/// The number of arrays differs from the number of layer weights, or an
/// array is not shape-compatible with the weight at the same position.
/// </exception>
public virtual void set_weights(IEnumerable<NDArray> weights)
{
    if (weights is null)
    {
        throw new System.ArgumentNullException(nameof(weights));
    }

    // Materialize once: the argument may be a lazy sequence, and it is
    // needed for the count check, the shape check and the assignment.
    var new_values = weights.ToList();
    // Read the property a single time so every step below sees the same list.
    var layer_weights = Weights;

    if (layer_weights.Count != new_values.Count)
    {
        throw new ValueError(
            $"You called `set_weights` on layer \"{this.name}\" " +
            $"with a weight list of length {new_values.Count}, but the layer was " +
            $"expecting {layer_weights.Count} weights.");
    }

    // Check that every provided array is shape-compatible with its target
    // variable before touching any of them.
    for (var i = 0; i < new_values.Count; i++)
    {
        if (!layer_weights[i].AsTensor().is_compatible_with(new_values[i]))
        {
            throw new ValueError(
                $"Layer weight shape {layer_weights[i].shape} not compatible with " +
                $"provided weight shape {new_values[i].shape}");
        }
    }

    if (tf.executing_eagerly())
    {
        for (var i = 0; i < new_values.Count; i++)
        {
            layer_weights[i].assign(new_values[i], read_value: true);
        }
    }
    else
    {
        // NOTE: graph mode is intentionally a no-op for now; the weights are NOT updated.
        // TODO(Wanglongzhi2001): there seems to be a bug in graph mode when defining a model,
        // so uncomment the following once it is fixed.

        //Tensors assign_ops = new Tensors();
        //var feed_dict = new FeedDict();

        //Graph g = tf.Graph().as_default();
        //foreach (var (this_w, v_w) in zip(Weights, weights))
        //{
        //    var tf_dtype = this_w.dtype;
        //    var placeholder_shape = v_w.shape;
        //    var assign_placeholder = tf.placeholder(tf_dtype, placeholder_shape);
        //    var assign_op = this_w.assign(assign_placeholder);
        //    assign_ops.Add(assign_op);
        //    feed_dict.Add(assign_placeholder, v_w);
        //}
        //var sess = tf.Session().as_default();
        //sess.run(assign_ops, feed_dict);

        //g.Exit();
    }
}

/// <summary>
/// Returns the current values of this layer's weights.
/// </summary>
/// <returns>
/// A new list holding the result of <c>numpy()</c> for each entry of
/// <see cref="Weights"/>, in the same order.
/// </returns>
public List<NDArray> get_weights()
{
    var layer_weights = Weights;
    var values = new List<NDArray>(layer_weights.Count);
    foreach (var variable in layer_weights)
    {
        values.Add(variable.numpy());
    }
    return values;
}

protected int id;
public int Id => id;
protected string name;


Loading…
Cancel
Save