@@ -250,7 +250,7 @@ namespace Tensorflow.Keras.Engine | |||||
TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
IInitializer initializer = null, | IInitializer initializer = null, | ||||
bool? trainable = null, | bool? trainable = null, | ||||
Func<string, int[], TF_DataType, IInitializer, bool, IVariableV1> getter = null) | |||||
Func<VariableArgs, IVariableV1> getter = null) | |||||
{ | { | ||||
if (dtype == TF_DataType.DtInvalid) | if (dtype == TF_DataType.DtInvalid) | ||||
dtype = TF_DataType.TF_FLOAT; | dtype = TF_DataType.TF_FLOAT; | ||||
@@ -259,7 +259,7 @@ namespace Tensorflow.Keras.Engine | |||||
trainable = true; | trainable = true; | ||||
// Initialize variable when no initializer provided | // Initialize variable when no initializer provided | ||||
if(initializer == null) | |||||
if (initializer == null) | |||||
{ | { | ||||
// If dtype is DT_FLOAT, provide a uniform unit scaling initializer | // If dtype is DT_FLOAT, provide a uniform unit scaling initializer | ||||
if (dtype.is_floating()) | if (dtype.is_floating()) | ||||
@@ -269,13 +269,18 @@ namespace Tensorflow.Keras.Engine | |||||
else | else | ||||
throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.name}"); | throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.name}"); | ||||
} | } | ||||
var variable = _add_variable_with_custom_getter(name, | |||||
shape, | |||||
dtype: dtype, | |||||
getter: (getter == null) ? base_layer_utils.make_variable : getter, | |||||
overwrite: true, | |||||
initializer: initializer, | |||||
trainable: trainable.Value); | |||||
var variable = _add_variable_with_custom_getter(new VariableArgs | |||||
{ | |||||
Name = name, | |||||
Shape = shape, | |||||
DType = dtype, | |||||
Getter = getter ?? base_layer_utils.make_variable, | |||||
Overwrite = true, | |||||
Initializer = initializer, | |||||
Trainable = trainable.Value | |||||
}); | |||||
//backend.track_variable(variable); | //backend.track_variable(variable); | ||||
if (trainable == true) | if (trainable == true) | ||||
_trainable_weights.Add(variable); | _trainable_weights.Add(variable); | ||||
@@ -199,8 +199,8 @@ namespace Tensorflow.Keras.Optimizers | |||||
} | } | ||||
} | } | ||||
ResourceVariable add_weight(string name, | |||||
TensorShape shape, | |||||
ResourceVariable add_weight(string name, | |||||
TensorShape shape, | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT, | TF_DataType dtype = TF_DataType.TF_FLOAT, | ||||
IInitializer initializer = null, | IInitializer initializer = null, | ||||
bool trainable = false, | bool trainable = false, | ||||
@@ -213,16 +213,19 @@ namespace Tensorflow.Keras.Optimizers | |||||
if (dtype == TF_DataType.DtInvalid) | if (dtype == TF_DataType.DtInvalid) | ||||
dtype = TF_DataType.TF_FLOAT; | dtype = TF_DataType.TF_FLOAT; | ||||
var variable = _add_variable_with_custom_getter(name: name, | |||||
shape: shape, | |||||
getter: base_layer_utils.make_variable, | |||||
dtype: dtype, | |||||
overwrite: true, | |||||
initializer: initializer, | |||||
trainable: trainable, | |||||
use_resource: true, | |||||
synchronization: synchronization, | |||||
aggregation: aggregation); | |||||
var variable = _add_variable_with_custom_getter(new VariableArgs | |||||
{ | |||||
Name = name, | |||||
Shape = shape, | |||||
Getter = base_layer_utils.make_variable, | |||||
DType = dtype, | |||||
Overwrite = true, | |||||
Initializer = initializer, | |||||
Trainable = trainable, | |||||
UseResource = true, | |||||
Synchronization = synchronization, | |||||
Aggregation = aggregation | |||||
}); | |||||
return variable as ResourceVariable; | return variable as ResourceVariable; | ||||
} | } | ||||
@@ -26,32 +26,26 @@ namespace Tensorflow.Keras.Utils | |||||
/// <summary> | /// <summary> | ||||
/// Adds a new variable to the layer. | /// Adds a new variable to the layer. | ||||
/// </summary> | /// </summary> | ||||
/// <param name="name"></param> | |||||
/// <param name="shape"></param> | |||||
/// <param name="dtype"></param> | |||||
/// <param name="initializer"></param> | |||||
/// <param name="trainable"></param> | |||||
/// <param name="args"></param> | |||||
/// <returns></returns> | /// <returns></returns> | ||||
public static IVariableV1 make_variable(string name, | |||||
int[] shape, | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
IInitializer initializer = null, | |||||
bool trainable = true) | |||||
public static IVariableV1 make_variable(VariableArgs args) | |||||
{ | { | ||||
#pragma warning disable CS0219 // Variable is assigned but its value is never used | #pragma warning disable CS0219 // Variable is assigned but its value is never used | ||||
var initializing_from_value = false; | var initializing_from_value = false; | ||||
bool use_resource = true; | |||||
#pragma warning restore CS0219 // Variable is assigned but its value is never used | #pragma warning restore CS0219 // Variable is assigned but its value is never used | ||||
ops.init_scope(); | ops.init_scope(); | ||||
Func<Tensor> init_val = () => initializer.call(new TensorShape(shape), dtype: dtype); | |||||
Func<Tensor> init_val = () => args.Initializer.call(args.Shape, dtype: args.DType); | |||||
var variable_dtype = dtype.as_base_dtype(); | |||||
var variable_dtype = args.DType.as_base_dtype(); | |||||
var v = tf.Variable(init_val, | var v = tf.Variable(init_val, | ||||
dtype: dtype, | |||||
shape: shape, | |||||
name: name); | |||||
dtype: args.DType, | |||||
shape: args.Shape, | |||||
name: args.Name, | |||||
trainable: args.Trainable, | |||||
validate_shape: args.ValidateShape, | |||||
use_resource: args.UseResource); | |||||
return v; | return v; | ||||
} | } | ||||
@@ -167,12 +167,12 @@ namespace Tensorflow.Layers | |||||
dtype: dtype, | dtype: dtype, | ||||
initializer: initializer, | initializer: initializer, | ||||
trainable: trainable, | trainable: trainable, | ||||
getter: (name1, shape1, dtype1, initializer1, trainable1) => | |||||
tf.compat.v1.get_variable(name1, | |||||
shape: new TensorShape(shape1), | |||||
dtype: dtype1, | |||||
initializer: initializer1, | |||||
trainable: trainable1) | |||||
getter: (args) => | |||||
tf.compat.v1.get_variable(args.Name, | |||||
shape: args.Shape, | |||||
dtype: args.DType, | |||||
initializer: args.Initializer, | |||||
trainable: args.Trainable) | |||||
); | ); | ||||
//if (init_graph != null) | //if (init_graph != null) | ||||
@@ -27,16 +27,7 @@ namespace Tensorflow.Train | |||||
/// Restore-on-create for a variable be saved with this `Checkpointable`. | /// Restore-on-create for a variable be saved with this `Checkpointable`. | ||||
/// </summary> | /// </summary> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
protected virtual IVariableV1 _add_variable_with_custom_getter(string name, | |||||
int[] shape, | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
IInitializer initializer = null, | |||||
Func<string, int[], TF_DataType, IInitializer, bool, IVariableV1> getter = null, | |||||
bool overwrite = false, | |||||
bool trainable = false, | |||||
bool use_resource = false, | |||||
VariableSynchronization synchronization = VariableSynchronization.Auto, | |||||
VariableAggregation aggregation = VariableAggregation.None) | |||||
protected virtual IVariableV1 _add_variable_with_custom_getter(VariableArgs args) | |||||
{ | { | ||||
ops.init_scope(); | ops.init_scope(); | ||||
#pragma warning disable CS0219 // Variable is assigned but its value is never used | #pragma warning disable CS0219 // Variable is assigned but its value is never used | ||||
@@ -50,15 +41,15 @@ namespace Tensorflow.Train | |||||
checkpoint_initializer = null; | checkpoint_initializer = null; | ||||
IVariableV1 new_variable; | IVariableV1 new_variable; | ||||
new_variable = getter(name, shape, dtype, initializer, trainable); | |||||
new_variable = args.Getter(args); | |||||
// If we set an initializer and the variable processed it, tracking will not | // If we set an initializer and the variable processed it, tracking will not | ||||
// assign again. It will add this variable to our dependencies, and if there | // assign again. It will add this variable to our dependencies, and if there | ||||
// is a non-trivial restoration queued, it will handle that. This also | // is a non-trivial restoration queued, it will handle that. This also | ||||
// handles slot variables. | // handles slot variables. | ||||
if (!overwrite || new_variable is RefVariable) | |||||
return _track_checkpointable(new_variable, name: name, | |||||
overwrite: overwrite); | |||||
if (!args.Overwrite || new_variable is RefVariable) | |||||
return _track_checkpointable(new_variable, name: args.Name, | |||||
overwrite: args.Overwrite); | |||||
else | else | ||||
return new_variable; | return new_variable; | ||||
} | } | ||||
@@ -0,0 +1,26 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class VariableArgs | |||||
{ | |||||
public object InitialValue { get; set; } | |||||
public Func<VariableArgs, IVariableV1> Getter { get; set; } | |||||
public string Name { get; set; } | |||||
public TensorShape Shape { get; set; } | |||||
public TF_DataType DType { get; set; } = TF_DataType.DtInvalid; | |||||
public IInitializer Initializer { get; set; } | |||||
public bool Trainable { get; set; } | |||||
public bool ValidateShape { get; set; } = true; | |||||
public bool UseResource { get; set; } = true; | |||||
public bool Overwrite { get; set; } | |||||
public List<string> Collections { get; set; } | |||||
public string CachingDevice { get; set; } = ""; | |||||
public VariableDef VariableDef { get; set; } | |||||
public string ImportScope { get; set; } = ""; | |||||
public VariableSynchronization Synchronization { get; set; } = VariableSynchronization.Auto; | |||||
public VariableAggregation Aggregation { get; set; } = VariableAggregation.None; | |||||
} | |||||
} |
@@ -62,6 +62,7 @@ namespace Tensorflow | |||||
public ResourceVariable Variable<T>(T data, | public ResourceVariable Variable<T>(T data, | ||||
bool trainable = true, | bool trainable = true, | ||||
bool validate_shape = true, | bool validate_shape = true, | ||||
bool use_resource = true, | |||||
string name = null, | string name = null, | ||||
TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
int[] shape = null) | int[] shape = null) | ||||