Add set_weights and get_weights APIs (tags/v0.100.5-BERT-load)
@@ -1,5 +1,6 @@ | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.NumPy; | |||
using Tensorflow.Training; | |||
namespace Tensorflow.Keras | |||
@@ -18,6 +19,8 @@ namespace Tensorflow.Keras | |||
List<IVariableV1> TrainableWeights { get; } | |||
List<IVariableV1> NonTrainableWeights { get; } | |||
List<IVariableV1> Weights { get; set; } | |||
/// <summary>
/// Sets the layer's weight values from the supplied arrays.
/// </summary>
/// <param name="weights">Arrays to load into the layer; presumably one per entry of <c>Weights</c>, in the same order (confirm against the implementing class).</param>
void set_weights(IEnumerable<NDArray> weights); | |||
/// <summary>
/// Gets the layer's weight values as a list of arrays.
/// </summary>
/// <returns>A list of <see cref="NDArray"/> values; presumably one per entry of <c>Weights</c> (confirm against the implementing class).</returns>
List<NDArray> get_weights(); | |||
Shape OutputShape { get; } | |||
Shape BatchInputShape { get; } | |||
TensorShapeConfig BuildInputShape { get; } | |||
@@ -21,6 +21,7 @@ using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.NumPy; | |||
using Tensorflow.Operations; | |||
using Tensorflow.Train; | |||
using Tensorflow.Util; | |||
@@ -71,7 +72,10 @@ namespace Tensorflow | |||
public List<IVariableV1> TrainableVariables => throw new NotImplementedException(); | |||
public List<IVariableV1> TrainableWeights => throw new NotImplementedException(); | |||
public List<IVariableV1> Weights => throw new NotImplementedException(); | |||
public List<IVariableV1> Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } | |||
public List<NDArray> get_weights() => throw new NotImplementedException(); | |||
public void set_weights(IEnumerable<NDArray> weights) => throw new NotImplementedException(); | |||
public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException(); | |||
public Shape OutputShape => throw new NotImplementedException(); | |||
@@ -84,8 +88,6 @@ namespace Tensorflow | |||
protected bool built = false; | |||
public bool Built => built; | |||
List<IVariableV1> ILayer.Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } | |||
public RnnCell(bool trainable = true, | |||
string name = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
@@ -30,6 +30,9 @@ using Tensorflow.Training; | |||
using Tensorflow.Training.Saving.SavedModel; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
using Tensorflow.Framework; | |||
using Tensorflow.Sessions; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
@@ -134,6 +137,62 @@ namespace Tensorflow.Keras.Engine | |||
} | |||
} | |||
/// <summary>
/// Overwrites the values of this layer's weights with the supplied arrays.
/// </summary>
/// <param name="weights">
/// New values, one per entry of <c>Weights</c> and in the same order.
/// The sequence is enumerated exactly once.
/// </param>
/// <exception cref="ValueError">
/// Thrown when the number of supplied arrays differs from the number of layer
/// weights, or when a supplied array is not shape-compatible with the weight
/// it is meant to replace.
/// </exception>
/// <remarks>
/// Values are only assigned when executing eagerly. In graph mode the call
/// currently validates its input and then does nothing (see the TODO below).
/// </remarks>
public virtual void set_weights(IEnumerable<NDArray> weights)
{
    // Materialize once: the argument may be a lazy sequence, and it is needed
    // for counting, validation and assignment.
    var provided = new List<NDArray>(weights);
    // Read the property once so every step below works on the same list.
    var current = Weights;

    if (current.Count != provided.Count)
    {
        throw new ValueError(
            $"You called `set_weights` on layer \"{this.name}\" " +
            $"with a weight list of length {provided.Count}, but the layer was " +
            $"expecting {current.Count} weights.");
    }

    // Validate every shape before touching any variable, so a failure never
    // leaves the layer partially updated.
    for (var i = 0; i < current.Count; i++)
    {
        if (!current[i].AsTensor().is_compatible_with(provided[i]))
        {
            throw new ValueError($"Layer weight shape {current[i].shape} not compatible with provided weight shape {provided[i].shape}");
        }
    }

    if (tf.executing_eagerly())
    {
        for (var i = 0; i < current.Count; i++)
        {
            current[i].assign(provided[i], read_value: true);
        }
    }
    else
    {
        // NOTE(review): graph mode is a silent no-op today; callers get no
        // signal that the weights were not updated.
        // TODO(Wanglongzhi2001): there seems to be a bug in graph mode when defining a model;
        // enable the following once it is fixed.
        //Tensors assign_ops = new Tensors();
        //var feed_dict = new FeedDict();
        //Graph g = tf.Graph().as_default();
        //foreach (var (this_w, v_w) in zip(Weights, weights))
        //{
        //    var tf_dtype = this_w.dtype;
        //    var placeholder_shape = v_w.shape;
        //    var assign_placeholder = tf.placeholder(tf_dtype, placeholder_shape);
        //    var assign_op = this_w.assign(assign_placeholder);
        //    assign_ops.Add(assign_op);
        //    feed_dict.Add(assign_placeholder, v_w);
        //}
        //var sess = tf.Session().as_default();
        //sess.run(assign_ops, feed_dict);
        //g.Exit();
    }
}
/// <summary>
/// Returns the current values of all layer weights as arrays.
/// </summary>
/// <returns>
/// A newly allocated list holding the <c>numpy()</c> value of each entry of
/// <c>Weights</c>, in the same order.
/// </returns>
public List<NDArray> get_weights()
{
    // ConvertAll already produces a fresh list, so no intermediate copy is needed.
    var values = Weights.ConvertAll(variable => variable.numpy());
    return values;
}
protected int id; | |||
public int Id => id; | |||
protected string name; | |||