@@ -250,7 +250,7 @@ namespace Tensorflow.Keras.Engine | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
IInitializer initializer = null, | |||
bool? trainable = null, | |||
Func<string, int[], TF_DataType, IInitializer, bool, IVariableV1> getter = null) | |||
Func<VariableArgs, IVariableV1> getter = null) | |||
{ | |||
if (dtype == TF_DataType.DtInvalid) | |||
dtype = TF_DataType.TF_FLOAT; | |||
@@ -259,7 +259,7 @@ namespace Tensorflow.Keras.Engine | |||
trainable = true; | |||
// Initialize variable when no initializer provided | |||
if(initializer == null) | |||
if (initializer == null) | |||
{ | |||
// If dtype is DT_FLOAT, provide a uniform unit scaling initializer | |||
if (dtype.is_floating()) | |||
@@ -269,13 +269,18 @@ namespace Tensorflow.Keras.Engine | |||
else | |||
throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.name}"); | |||
} | |||
var variable = _add_variable_with_custom_getter(name, | |||
shape, | |||
dtype: dtype, | |||
getter: (getter == null) ? base_layer_utils.make_variable : getter, | |||
overwrite: true, | |||
initializer: initializer, | |||
trainable: trainable.Value); | |||
var variable = _add_variable_with_custom_getter(new VariableArgs | |||
{ | |||
Name = name, | |||
Shape = shape, | |||
DType = dtype, | |||
Getter = getter ?? base_layer_utils.make_variable, | |||
Overwrite = true, | |||
Initializer = initializer, | |||
Trainable = trainable.Value | |||
}); | |||
//backend.track_variable(variable); | |||
if (trainable == true) | |||
_trainable_weights.Add(variable); | |||
@@ -199,8 +199,8 @@ namespace Tensorflow.Keras.Optimizers | |||
} | |||
} | |||
ResourceVariable add_weight(string name, | |||
TensorShape shape, | |||
ResourceVariable add_weight(string name, | |||
TensorShape shape, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
IInitializer initializer = null, | |||
bool trainable = false, | |||
@@ -213,16 +213,19 @@ namespace Tensorflow.Keras.Optimizers | |||
if (dtype == TF_DataType.DtInvalid) | |||
dtype = TF_DataType.TF_FLOAT; | |||
var variable = _add_variable_with_custom_getter(name: name, | |||
shape: shape, | |||
getter: base_layer_utils.make_variable, | |||
dtype: dtype, | |||
overwrite: true, | |||
initializer: initializer, | |||
trainable: trainable, | |||
use_resource: true, | |||
synchronization: synchronization, | |||
aggregation: aggregation); | |||
var variable = _add_variable_with_custom_getter(new VariableArgs | |||
{ | |||
Name = name, | |||
Shape = shape, | |||
Getter = base_layer_utils.make_variable, | |||
DType = dtype, | |||
Overwrite = true, | |||
Initializer = initializer, | |||
Trainable = trainable, | |||
UseResource = true, | |||
Synchronization = synchronization, | |||
Aggregation = aggregation | |||
}); | |||
return variable as ResourceVariable; | |||
} | |||
@@ -26,32 +26,26 @@ namespace Tensorflow.Keras.Utils | |||
/// <summary>
/// Adds a new variable to the layer.
/// </summary>
/// <param name="args">
/// Bundled creation arguments: <c>Name</c>, <c>Shape</c>, <c>DType</c>,
/// <c>Initializer</c>, <c>Trainable</c>, <c>ValidateShape</c> and
/// <c>UseResource</c> are all forwarded to <c>tf.Variable</c>.
/// </param>
/// <returns>The newly created variable.</returns>
public static IVariableV1 make_variable(VariableArgs args)
{
    ops.init_scope();

    // Wrap initialization in a delegate so the variable can (re-)run the
    // initializer lazily instead of materializing the tensor eagerly here.
    Func<Tensor> init_val = () => args.Initializer.call(args.Shape, dtype: args.DType);

    var v = tf.Variable(init_val,
        dtype: args.DType,
        shape: args.Shape,
        name: args.Name,
        trainable: args.Trainable,
        validate_shape: args.ValidateShape,
        use_resource: args.UseResource);

    return v;
}
@@ -167,12 +167,12 @@ namespace Tensorflow.Layers | |||
dtype: dtype, | |||
initializer: initializer, | |||
trainable: trainable, | |||
getter: (name1, shape1, dtype1, initializer1, trainable1) => | |||
tf.compat.v1.get_variable(name1, | |||
shape: new TensorShape(shape1), | |||
dtype: dtype1, | |||
initializer: initializer1, | |||
trainable: trainable1) | |||
getter: (args) => | |||
tf.compat.v1.get_variable(args.Name, | |||
shape: args.Shape, | |||
dtype: args.DType, | |||
initializer: args.Initializer, | |||
trainable: args.Trainable) | |||
); | |||
//if (init_graph != null) | |||
@@ -27,16 +27,7 @@ namespace Tensorflow.Train | |||
/// Restore-on-create for a variable be saved with this `Checkpointable`. | |||
/// </summary> | |||
/// <returns></returns> | |||
protected virtual IVariableV1 _add_variable_with_custom_getter(string name, | |||
int[] shape, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
IInitializer initializer = null, | |||
Func<string, int[], TF_DataType, IInitializer, bool, IVariableV1> getter = null, | |||
bool overwrite = false, | |||
bool trainable = false, | |||
bool use_resource = false, | |||
VariableSynchronization synchronization = VariableSynchronization.Auto, | |||
VariableAggregation aggregation = VariableAggregation.None) | |||
protected virtual IVariableV1 _add_variable_with_custom_getter(VariableArgs args) | |||
{ | |||
ops.init_scope(); | |||
#pragma warning disable CS0219 // Variable is assigned but its value is never used | |||
@@ -50,15 +41,15 @@ namespace Tensorflow.Train | |||
checkpoint_initializer = null; | |||
IVariableV1 new_variable; | |||
new_variable = getter(name, shape, dtype, initializer, trainable); | |||
new_variable = args.Getter(args); | |||
// If we set an initializer and the variable processed it, tracking will not | |||
// assign again. It will add this variable to our dependencies, and if there | |||
// is a non-trivial restoration queued, it will handle that. This also | |||
// handles slot variables. | |||
if (!overwrite || new_variable is RefVariable) | |||
return _track_checkpointable(new_variable, name: name, | |||
overwrite: overwrite); | |||
if (!args.Overwrite || new_variable is RefVariable) | |||
return _track_checkpointable(new_variable, name: args.Name, | |||
overwrite: args.Overwrite); | |||
else | |||
return new_variable; | |||
} | |||
@@ -0,0 +1,26 @@ | |||
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
    /// <summary>
    /// Parameter object bundling everything needed to create (or re-create)
    /// a variable, so variable getters share a single
    /// <c>Func&lt;VariableArgs, IVariableV1&gt;</c> signature instead of a
    /// long positional parameter list.
    /// </summary>
    public class VariableArgs
    {
        /// <summary>Initial value for the variable (tensor-like or a callable producing one).</summary>
        public object InitialValue { get; set; }

        /// <summary>Factory that actually creates the variable; it is invoked with this same args object.</summary>
        public Func<VariableArgs, IVariableV1> Getter { get; set; }

        /// <summary>Name of the variable.</summary>
        public string Name { get; set; }

        /// <summary>Shape of the variable.</summary>
        public TensorShape Shape { get; set; }

        /// <summary>
        /// Element data type. <c>DtInvalid</c> means "unspecified", letting
        /// callers substitute their own default (typically <c>TF_FLOAT</c>).
        /// </summary>
        public TF_DataType DType { get; set; } = TF_DataType.DtInvalid;

        /// <summary>Initializer used to produce the variable's initial value.</summary>
        public IInitializer Initializer { get; set; }

        /// <summary>Forwarded to variable creation as the <c>trainable</c> flag.</summary>
        public bool Trainable { get; set; }

        /// <summary>Whether the initial value's shape must be validated against <see cref="Shape"/>.</summary>
        public bool ValidateShape { get; set; } = true;

        /// <summary>Create a resource variable rather than a legacy ref variable.</summary>
        public bool UseResource { get; set; } = true;

        /// <summary>Whether an existing trackable dependency with the same name may be overwritten.</summary>
        public bool Overwrite { get; set; }

        // NOTE(review): the remaining members mirror tf.Variable's optional
        // arguments; they are not consumed in the visible call sites yet.

        /// <summary>Graph collections the variable should be added to.</summary>
        public List<string> Collections { get; set; }

        /// <summary>Device on which to cache reads; empty string means no caching device.</summary>
        public string CachingDevice { get; set; } = "";

        /// <summary>Optional protocol buffer describing a pre-existing variable to recreate.</summary>
        public VariableDef VariableDef { get; set; }

        /// <summary>Name scope to use when importing from <see cref="VariableDef"/>.</summary>
        public string ImportScope { get; set; } = "";

        /// <summary>How the variable is synchronized across replicas.</summary>
        public VariableSynchronization Synchronization { get; set; } = VariableSynchronization.Auto;

        /// <summary>How the variable is aggregated across replicas.</summary>
        public VariableAggregation Aggregation { get; set; } = VariableAggregation.None;
    }
}
@@ -62,6 +62,7 @@ namespace Tensorflow | |||
public ResourceVariable Variable<T>(T data, | |||
bool trainable = true, | |||
bool validate_shape = true, | |||
bool use_resource = true, | |||
string name = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
int[] shape = null) | |||