From 78bd4c758e1e4f2cf2025c6a630cbcf6e419f709 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Thu, 20 Apr 2023 10:49:02 +0800 Subject: [PATCH] Add api set_weights and get_weights --- src/TensorFlowNET.Core/Keras/Layers/ILayer.cs | 3 +++ .../Operations/NnOps/RNNCell.cs | 8 ++++--- src/TensorFlowNET.Keras/Engine/Layer.cs | 24 +++++++++++++++++++ 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index 9d69d5d0..f16d54d1 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -1,5 +1,6 @@ using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using Tensorflow.NumPy; using Tensorflow.Training; namespace Tensorflow.Keras @@ -18,6 +19,8 @@ namespace Tensorflow.Keras List TrainableWeights { get; } List NonTrainableWeights { get; } List Weights { get; set; } + void set_weights(List weights); + List get_weights(); Shape OutputShape { get; } Shape BatchInputShape { get; } TensorShapeConfig BuildInputShape { get; } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index bc4daf13..93e0edf0 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -21,6 +21,7 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using Tensorflow.NumPy; using Tensorflow.Operations; using Tensorflow.Train; using Tensorflow.Util; @@ -71,7 +72,10 @@ namespace Tensorflow public List TrainableVariables => throw new NotImplementedException(); public List TrainableWeights => throw new NotImplementedException(); - public List Weights => throw new NotImplementedException(); + public List Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } + + public List get_weights() => throw new NotImplementedException(); + public void set_weights(List weights) => throw new NotImplementedException(); public List NonTrainableWeights => throw new NotImplementedException(); public Shape OutputShape => throw new NotImplementedException(); @@ -84,8 +88,6 @@ namespace Tensorflow protected bool built = false; public bool Built => built; - List ILayer.Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } - public RnnCell(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid, diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 0f809cba..39ca1b35 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -120,6 +120,30 @@ namespace Tensorflow.Keras.Engine } } + public virtual void set_weights(List weights) + { + if (Weights.Count() != weights.Count()) throw new ValueError( + $"You called `set_weights` on layer \"{this.name}\"" + + $"with a weight list of length {len(weights)}, but the layer was " + + $"expecting {len(Weights)} weights."); + for (int i = 0; i < weights.Count(); i++) + { + if (weights[i].shape != Weights[i].shape) + { + throw new ValueError($"Layer weight shape {weights[i].shape} not compatible with provided weight shape {Weights[i].shape}"); + } + } + foreach (var (this_w, v_w) in zip(Weights, weights)) + this_w.assign(v_w, read_value: true); + } + + public List get_weights() + { + List weights = new List(); + weights.AddRange(Weights.ConvertAll(x => x.numpy())); + return weights; + } + protected int id; public int Id => id; protected string name;