diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
index 9d69d5d0..1e473d75 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
@@ -1,5 +1,6 @@
 using Tensorflow.Keras.Engine;
 using Tensorflow.Keras.Saving;
+using Tensorflow.NumPy;
 using Tensorflow.Training;
 
 namespace Tensorflow.Keras
@@ -18,6 +19,10 @@ namespace Tensorflow.Keras
         List<IVariableV1> TrainableWeights { get; }
         List<IVariableV1> NonTrainableWeights { get; }
         List<IVariableV1> Weights { get; set; }
+        /// <summary>Assigns the supplied values to the layer's weights.</summary>
+        void set_weights(IEnumerable<NDArray> weights);
+        /// <summary>Returns the current values of the layer's weights.</summary>
+        List<NDArray> get_weights();
         Shape OutputShape { get; }
         Shape BatchInputShape { get; }
         TensorShapeConfig BuildInputShape { get; }
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
index bc4daf13..5847e31a 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
@@ -21,6 +21,7 @@ using Tensorflow.Keras.ArgsDefinition;
 using Tensorflow.Keras.ArgsDefinition.Rnn;
 using Tensorflow.Keras.Engine;
 using Tensorflow.Keras.Saving;
+using Tensorflow.NumPy;
 using Tensorflow.Operations;
 using Tensorflow.Train;
 using Tensorflow.Util;
@@ -71,7 +72,10 @@ namespace Tensorflow
 
         public List<IVariableV1> TrainableVariables => throw new NotImplementedException();
         public List<IVariableV1> TrainableWeights => throw new NotImplementedException();
-        public List<IVariableV1> Weights => throw new NotImplementedException();
+        public List<IVariableV1> Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
+
+        public List<NDArray> get_weights() => throw new NotImplementedException();
+        public void set_weights(IEnumerable<NDArray> weights) => throw new NotImplementedException();
         public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException();
 
         public Shape OutputShape => throw new NotImplementedException();
@@ -84,8 +88,6 @@ namespace Tensorflow
         protected bool built = false;
         public bool Built => built;
 
-        List<IVariableV1> ILayer.Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
-
         public RnnCell(bool trainable = true,
             string name = null,
             TF_DataType dtype = TF_DataType.DtInvalid,
diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs
index 79c955b6..11a0584c 100644
--- a/src/TensorFlowNET.Keras/Engine/Layer.cs
+++ b/src/TensorFlowNET.Keras/Engine/Layer.cs
@@ -30,6 +30,9 @@ using Tensorflow.Training;
 using Tensorflow.Training.Saving.SavedModel;
 using Tensorflow.Util;
 using static Tensorflow.Binding;
+using Tensorflow.Framework;
+using Tensorflow.Sessions;
+
 
 namespace Tensorflow.Keras.Engine
 {
@@ -134,6 +137,79 @@ namespace Tensorflow.Keras.Engine
             }
         }
 
+        /// <summary>
+        /// Assigns the supplied values to this layer's weights, pairing them by position with <see cref="Weights"/>.
+        /// </summary>
+        /// <param name="weights">One array per layer weight, in the same order as <see cref="Weights"/>.</param>
+        /// <exception cref="ValueError">
+        /// The number of arrays differs from the number of layer weights, or an array is not shape-compatible with its weight.
+        /// </exception>
+        /// <remarks>Only eager execution is implemented; in graph mode the call currently validates and then does nothing.</remarks>
+        public virtual void set_weights(IEnumerable<NDArray> weights)
+        {
+            // Materialize once: the sequence is counted, validated and assigned, so a lazy
+            // enumerable would otherwise be evaluated several times.
+            var new_values = weights.ToList();
+
+            if (Weights.Count != new_values.Count)
+            {
+                throw new ValueError(
+                    $"You called `set_weights` on layer \"{this.name}\" " +
+                    $"with a weight list of length {new_values.Count}, but the layer was " +
+                    $"expecting {Weights.Count} weights.");
+            }
+
+            // Check every shape before assigning anything, so a mismatch cannot leave the layer partially updated.
+            for (var i = 0; i < new_values.Count; i++)
+            {
+                if (!Weights[i].AsTensor().is_compatible_with(new_values[i]))
+                {
+                    throw new ValueError(
+                        $"Layer weight shape {Weights[i].shape} not compatible with provided weight shape {new_values[i].shape}");
+                }
+            }
+
+            if (tf.executing_eagerly())
+            {
+                for (var i = 0; i < new_values.Count; i++)
+                {
+                    Weights[i].assign(new_values[i], read_value: true);
+                }
+            }
+            else
+            {
+                // TODO(Wanglongzhi2001): there seems to be a bug in graph mode when defining a model;
+                // uncomment the following once it is fixed. Until then this branch is a no-op.
+
+                //Tensors assign_ops = new Tensors();
+                //var feed_dict = new FeedDict();
+
+                //Graph g = tf.Graph().as_default();
+                //foreach (var (this_w, v_w) in zip(Weights, weights))
+                //{
+                //    var tf_dtype = this_w.dtype;
+                //    var placeholder_shape = v_w.shape;
+                //    var assign_placeholder = tf.placeholder(tf_dtype, placeholder_shape);
+                //    var assign_op = this_w.assign(assign_placeholder);
+                //    assign_ops.Add(assign_op);
+                //    feed_dict.Add(assign_placeholder, v_w);
+                //}
+                //var sess = tf.Session().as_default();
+                //sess.run(assign_ops, feed_dict);
+
+                //g.Exit();
+            }
+        }
+
+        /// <summary>
+        /// Returns the current value of every entry in <see cref="Weights"/> as an NDArray, in the same order.
+        /// </summary>
+        public List<NDArray> get_weights()
+        {
+            // ConvertAll already builds a new list, so no intermediate list is needed.
+            return Weights.ConvertAll(x => x.numpy());
+        }
+
         protected int id;
         public int Id => id;
         protected string name;