diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs index 60379a0c..349d7908 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs @@ -250,7 +250,7 @@ namespace Tensorflow.Keras.Engine TF_DataType dtype = TF_DataType.DtInvalid, IInitializer initializer = null, bool? trainable = null, - Func getter = null) + Func getter = null) { if (dtype == TF_DataType.DtInvalid) dtype = TF_DataType.TF_FLOAT; @@ -259,7 +259,7 @@ namespace Tensorflow.Keras.Engine trainable = true; // Initialize variable when no initializer provided - if(initializer == null) + if (initializer == null) { // If dtype is DT_FLOAT, provide a uniform unit scaling initializer if (dtype.is_floating()) @@ -269,13 +269,18 @@ namespace Tensorflow.Keras.Engine else throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.name}"); } - var variable = _add_variable_with_custom_getter(name, - shape, - dtype: dtype, - getter: (getter == null) ? base_layer_utils.make_variable : getter, - overwrite: true, - initializer: initializer, - trainable: trainable.Value); + + var variable = _add_variable_with_custom_getter(new VariableArgs + { + Name = name, + Shape = shape, + DType = dtype, + Getter = getter ?? base_layer_utils.make_variable, + Overwrite = true, + Initializer = initializer, + Trainable = trainable.Value + }); + //backend.track_variable(variable); if (trainable == true) _trainable_weights.Add(variable); diff --git a/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs b/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs index 5c75e9bf..6d29f95d 100644 --- a/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs +++ b/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs @@ -199,8 +199,8 @@ namespace Tensorflow.Keras.Optimizers } } - ResourceVariable add_weight(string name, - TensorShape shape, + ResourceVariable add_weight(string name, + TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, IInitializer initializer = null, bool trainable = false, @@ -213,16 +213,19 @@ namespace Tensorflow.Keras.Optimizers if (dtype == TF_DataType.DtInvalid) dtype = TF_DataType.TF_FLOAT; - var variable = _add_variable_with_custom_getter(name: name, - shape: shape, - getter: base_layer_utils.make_variable, - dtype: dtype, - overwrite: true, - initializer: initializer, - trainable: trainable, - use_resource: true, - synchronization: synchronization, - aggregation: aggregation); + var variable = _add_variable_with_custom_getter(new VariableArgs + { + Name = name, + Shape = shape, + Getter = base_layer_utils.make_variable, + DType = dtype, + Overwrite = true, + Initializer = initializer, + Trainable = trainable, + UseResource = true, + Synchronization = synchronization, + Aggregation = aggregation + }); return variable as ResourceVariable; } diff --git a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs index 76a20bcf..a3667867 100644 --- a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs @@ -26,32 +26,26 @@ namespace Tensorflow.Keras.Utils /// /// Adds a new variable to the layer. /// - /// - /// - /// - /// - /// + /// /// - public static IVariableV1 make_variable(string name, - int[] shape, - TF_DataType dtype = TF_DataType.TF_FLOAT, - IInitializer initializer = null, - bool trainable = true) + public static IVariableV1 make_variable(VariableArgs args) { #pragma warning disable CS0219 // Variable is assigned but its value is never used var initializing_from_value = false; - bool use_resource = true; #pragma warning restore CS0219 // Variable is assigned but its value is never used ops.init_scope(); - Func init_val = () => initializer.call(new TensorShape(shape), dtype: dtype); + Func init_val = () => args.Initializer.call(args.Shape, dtype: args.DType); - var variable_dtype = dtype.as_base_dtype(); + var variable_dtype = args.DType.as_base_dtype(); var v = tf.Variable(init_val, - dtype: dtype, - shape: shape, - name: name); + dtype: args.DType, + shape: args.Shape, + name: args.Name, + trainable: args.Trainable, + validate_shape: args.ValidateShape, + use_resource: args.UseResource); return v; } diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index 5c0ad97e..2d98d081 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -167,12 +167,12 @@ namespace Tensorflow.Layers dtype: dtype, initializer: initializer, trainable: trainable, - getter: (name1, shape1, dtype1, initializer1, trainable1) => - tf.compat.v1.get_variable(name1, - shape: new TensorShape(shape1), - dtype: dtype1, - initializer: initializer1, - trainable: trainable1) + getter: (args) => + tf.compat.v1.get_variable(args.Name, + shape: args.Shape, + dtype: args.DType, + initializer: args.Initializer, + trainable: args.Trainable) ); //if (init_graph != null) diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index 518c530b..74fe0b69 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -27,16 +27,7 @@ namespace Tensorflow.Train /// Restore-on-create for a variable be saved with this `Checkpointable`. /// /// - protected virtual IVariableV1 _add_variable_with_custom_getter(string name, - int[] shape, - TF_DataType dtype = TF_DataType.TF_FLOAT, - IInitializer initializer = null, - Func getter = null, - bool overwrite = false, - bool trainable = false, - bool use_resource = false, - VariableSynchronization synchronization = VariableSynchronization.Auto, - VariableAggregation aggregation = VariableAggregation.None) + protected virtual IVariableV1 _add_variable_with_custom_getter(VariableArgs args) { ops.init_scope(); #pragma warning disable CS0219 // Variable is assigned but its value is never used @@ -50,15 +41,15 @@ namespace Tensorflow.Train checkpoint_initializer = null; IVariableV1 new_variable; - new_variable = getter(name, shape, dtype, initializer, trainable); + new_variable = args.Getter(args); // If we set an initializer and the variable processed it, tracking will not // assign again. It will add this variable to our dependencies, and if there // is a non-trivial restoration queued, it will handle that. This also // handles slot variables. - if (!overwrite || new_variable is RefVariable) - return _track_checkpointable(new_variable, name: name, - overwrite: overwrite); + if (!args.Overwrite || new_variable is RefVariable) + return _track_checkpointable(new_variable, name: args.Name, + overwrite: args.Overwrite); else return new_variable; } diff --git a/src/TensorFlowNET.Core/Variables/VariableArgs.cs b/src/TensorFlowNET.Core/Variables/VariableArgs.cs new file mode 100644 index 00000000..cbb7524a --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/VariableArgs.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class VariableArgs + { + public object InitialValue { get; set; } + public Func Getter { get; set; } + public string Name { get; set; } + public TensorShape Shape { get; set; } + public TF_DataType DType { get; set; } = TF_DataType.DtInvalid; + public IInitializer Initializer { get; set; } + public bool Trainable { get; set; } + public bool ValidateShape { get; set; } = true; + public bool UseResource { get; set; } = true; + public bool Overwrite { get; set; } + public List Collections { get; set; } + public string CachingDevice { get; set; } = ""; + public VariableDef VariableDef { get; set; } + public string ImportScope { get; set; } = ""; + public VariableSynchronization Synchronization { get; set; } = VariableSynchronization.Auto; + public VariableAggregation Aggregation { get; set; } = VariableAggregation.None; + } +} diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index 535be9ff..91917549 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -62,6 +62,7 @@ namespace Tensorflow public ResourceVariable Variable(T data, bool trainable = true, bool validate_shape = true, + bool use_resource = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null)