From 78bd4c758e1e4f2cf2025c6a630cbcf6e419f709 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Thu, 20 Apr 2023 10:49:02 +0800 Subject: [PATCH 1/2] Add api set_weights and get_weights --- src/TensorFlowNET.Core/Keras/Layers/ILayer.cs | 3 +++ .../Operations/NnOps/RNNCell.cs | 8 ++++--- src/TensorFlowNET.Keras/Engine/Layer.cs | 24 +++++++++++++++++++ 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index 9d69d5d0..f16d54d1 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -1,5 +1,6 @@ using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using Tensorflow.NumPy; using Tensorflow.Training; namespace Tensorflow.Keras @@ -18,6 +19,8 @@ namespace Tensorflow.Keras List TrainableWeights { get; } List NonTrainableWeights { get; } List Weights { get; set; } + void set_weights(List weights); + List get_weights(); Shape OutputShape { get; } Shape BatchInputShape { get; } TensorShapeConfig BuildInputShape { get; } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index bc4daf13..93e0edf0 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -21,6 +21,7 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using Tensorflow.NumPy; using Tensorflow.Operations; using Tensorflow.Train; using Tensorflow.Util; @@ -71,7 +72,10 @@ namespace Tensorflow public List TrainableVariables => throw new NotImplementedException(); public List TrainableWeights => throw new NotImplementedException(); - public List Weights => throw new NotImplementedException(); + public List Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } + + public List get_weights() => throw new NotImplementedException(); + public void set_weights(List weights) => throw new NotImplementedException(); public List NonTrainableWeights => throw new NotImplementedException(); public Shape OutputShape => throw new NotImplementedException(); @@ -84,8 +88,6 @@ namespace Tensorflow protected bool built = false; public bool Built => built; - List ILayer.Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } - public RnnCell(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid, diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 0f809cba..39ca1b35 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -120,6 +120,30 @@ namespace Tensorflow.Keras.Engine } } + public virtual void set_weights(List weights) + { + if (Weights.Count() != weights.Count()) throw new ValueError( + $"You called `set_weights` on layer \"{this.name}\"" + + $"with a weight list of length {len(weights)}, but the layer was " + + $"expecting {len(Weights)} weights."); + for (int i = 0; i < weights.Count(); i++) + { + if (weights[i].shape != Weights[i].shape) + { + throw new ValueError($"Layer weight shape {weights[i].shape} not compatible with provided weight shape {Weights[i].shape}"); + } + } + foreach (var (this_w, v_w) in zip(Weights, weights)) + this_w.assign(v_w, read_value: true); + } + + public List get_weights() + { + List weights = new List(); + weights.AddRange(Weights.ConvertAll(x => x.numpy())); + return weights; + } + protected int id; public int Id => id; protected string name; From 426a55ce7b4d88bde0063ae9a7bd12e9262c9d14 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Fri, 21 Apr 2023 15:02:38 +0800 Subject: [PATCH 2/2] Add set_weights and get_weights APIs --- src/TensorFlowNET.Core/Keras/Layers/ILayer.cs | 2 +- .../Operations/NnOps/RNNCell.cs | 2 +- src/TensorFlowNET.Keras/Engine/Layer.cs | 47 ++++++++++++++++--- 3 files changed, 43 insertions(+), 8 deletions(-) diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index f16d54d1..1e473d75 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -19,7 +19,7 @@ namespace Tensorflow.Keras List TrainableWeights { get; } List NonTrainableWeights { get; } List Weights { get; set; } - void set_weights(List weights); + void set_weights(IEnumerable weights); List get_weights(); Shape OutputShape { get; } Shape BatchInputShape { get; } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index 93e0edf0..5847e31a 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -75,7 +75,7 @@ namespace Tensorflow public List Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } public List get_weights() => throw new NotImplementedException(); - public void set_weights(List weights) => throw new NotImplementedException(); + public void set_weights(IEnumerable weights) => throw new NotImplementedException(); public List NonTrainableWeights => throw new NotImplementedException(); public Shape OutputShape => throw new NotImplementedException(); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 31ac74dc..11a0584c 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -30,6 +30,9 @@ using Tensorflow.Training; using Tensorflow.Training.Saving.SavedModel; using Tensorflow.Util; using static Tensorflow.Binding; +using Tensorflow.Framework; +using Tensorflow.Sessions; + namespace Tensorflow.Keras.Engine { @@ -134,21 +137,53 @@ namespace Tensorflow.Keras.Engine } } - public virtual void set_weights(List weights) + public virtual void set_weights(IEnumerable weights) { if (Weights.Count() != weights.Count()) throw new ValueError( $"You called `set_weights` on layer \"{this.name}\"" + $"with a weight list of length {len(weights)}, but the layer was " + $"expecting {len(Weights)} weights."); - for (int i = 0; i < weights.Count(); i++) + + + + // check if the shapes are compatible + var weight_index = 0; + foreach(var w in weights) { - if (weights[i].shape != Weights[i].shape) + if (!Weights[weight_index].AsTensor().is_compatible_with(w)) { - throw new ValueError($"Layer weight shape {weights[i].shape} not compatible with provided weight shape {Weights[i].shape}"); + throw new ValueError($"Layer weight shape {w.shape} not compatible with provided weight shape {Weights[weight_index].shape}"); } + weight_index++; + } + + if (tf.executing_eagerly()) + { + foreach (var (this_w, v_w) in zip(Weights, weights)) + this_w.assign(v_w, read_value: true); + } + else + { + // TODO(Wanglongzhi2001):seems like there exist some bug in graph mode when define model, so uncomment the following when it fixed. + + //Tensors assign_ops = new Tensors(); + //var feed_dict = new FeedDict(); + + //Graph g = tf.Graph().as_default(); + //foreach (var (this_w, v_w) in zip(Weights, weights)) + //{ + // var tf_dtype = this_w.dtype; + // var placeholder_shape = v_w.shape; + // var assign_placeholder = tf.placeholder(tf_dtype, placeholder_shape); + // var assign_op = this_w.assign(assign_placeholder); + // assign_ops.Add(assign_op); + // feed_dict.Add(assign_placeholder, v_w); + //} + //var sess = tf.Session().as_default(); + //sess.run(assign_ops, feed_dict); + + //g.Exit(); } - foreach (var (this_w, v_w) in zip(Weights, weights)) - this_w.assign(v_w, read_value: true); } public List get_weights()