Browse Source

add VariableArgs.

tags/v0.20
Oceania2018 5 years ago
parent
commit
29042df369
7 changed files with 77 additions and 57 deletions
  1. +14
    -9
      src/TensorFlowNET.Core/Keras/Engine/Layer.cs
  2. +15
    -12
      src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs
  3. +10
    -16
      src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
  4. +6
    -6
      src/TensorFlowNET.Core/Layers/Layer.cs
  5. +5
    -14
      src/TensorFlowNET.Core/Training/Trackable.cs
  6. +26
    -0
      src/TensorFlowNET.Core/Variables/VariableArgs.cs
  7. +1
    -0
      src/TensorFlowNET.Core/tensorflow.cs

+ 14
- 9
src/TensorFlowNET.Core/Keras/Engine/Layer.cs View File

@@ -250,7 +250,7 @@ namespace Tensorflow.Keras.Engine
TF_DataType dtype = TF_DataType.DtInvalid,
IInitializer initializer = null,
bool? trainable = null,
Func<string, int[], TF_DataType, IInitializer, bool, IVariableV1> getter = null)
Func<VariableArgs, IVariableV1> getter = null)
{
if (dtype == TF_DataType.DtInvalid)
dtype = TF_DataType.TF_FLOAT;
@@ -259,7 +259,7 @@ namespace Tensorflow.Keras.Engine
trainable = true;

// Initialize variable when no initializer provided
if(initializer == null)
if (initializer == null)
{
// If dtype is DT_FLOAT, provide a uniform unit scaling initializer
if (dtype.is_floating())
@@ -269,13 +269,18 @@ namespace Tensorflow.Keras.Engine
else
throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.name}");
}
var variable = _add_variable_with_custom_getter(name,
shape,
dtype: dtype,
getter: (getter == null) ? base_layer_utils.make_variable : getter,
overwrite: true,
initializer: initializer,
trainable: trainable.Value);

var variable = _add_variable_with_custom_getter(new VariableArgs
{
Name = name,
Shape = shape,
DType = dtype,
Getter = getter ?? base_layer_utils.make_variable,
Overwrite = true,
Initializer = initializer,
Trainable = trainable.Value
});

//backend.track_variable(variable);
if (trainable == true)
_trainable_weights.Add(variable);


+ 15
- 12
src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs View File

@@ -199,8 +199,8 @@ namespace Tensorflow.Keras.Optimizers
}
}

ResourceVariable add_weight(string name,
TensorShape shape,
ResourceVariable add_weight(string name,
TensorShape shape,
TF_DataType dtype = TF_DataType.TF_FLOAT,
IInitializer initializer = null,
bool trainable = false,
@@ -213,16 +213,19 @@ namespace Tensorflow.Keras.Optimizers
if (dtype == TF_DataType.DtInvalid)
dtype = TF_DataType.TF_FLOAT;

var variable = _add_variable_with_custom_getter(name: name,
shape: shape,
getter: base_layer_utils.make_variable,
dtype: dtype,
overwrite: true,
initializer: initializer,
trainable: trainable,
use_resource: true,
synchronization: synchronization,
aggregation: aggregation);
var variable = _add_variable_with_custom_getter(new VariableArgs
{
Name = name,
Shape = shape,
Getter = base_layer_utils.make_variable,
DType = dtype,
Overwrite = true,
Initializer = initializer,
Trainable = trainable,
UseResource = true,
Synchronization = synchronization,
Aggregation = aggregation
});

return variable as ResourceVariable;
}


+ 10
- 16
src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs View File

@@ -26,32 +26,26 @@ namespace Tensorflow.Keras.Utils
/// <summary>
/// Adds a new variable to the layer.
/// </summary>
/// <param name="name"></param>
/// <param name="shape"></param>
/// <param name="dtype"></param>
/// <param name="initializer"></param>
/// <param name="trainable"></param>
/// <param name="args"></param>
/// <returns></returns>
public static IVariableV1 make_variable(string name,
int[] shape,
TF_DataType dtype = TF_DataType.TF_FLOAT,
IInitializer initializer = null,
bool trainable = true)
public static IVariableV1 make_variable(VariableArgs args)
{
#pragma warning disable CS0219 // Variable is assigned but its value is never used
var initializing_from_value = false;
bool use_resource = true;
#pragma warning restore CS0219 // Variable is assigned but its value is never used

ops.init_scope();

Func<Tensor> init_val = () => initializer.call(new TensorShape(shape), dtype: dtype);
Func<Tensor> init_val = () => args.Initializer.call(args.Shape, dtype: args.DType);

var variable_dtype = dtype.as_base_dtype();
var variable_dtype = args.DType.as_base_dtype();
var v = tf.Variable(init_val,
dtype: dtype,
shape: shape,
name: name);
dtype: args.DType,
shape: args.Shape,
name: args.Name,
trainable: args.Trainable,
validate_shape: args.ValidateShape,
use_resource: args.UseResource);

return v;
}


+ 6
- 6
src/TensorFlowNET.Core/Layers/Layer.cs View File

@@ -167,12 +167,12 @@ namespace Tensorflow.Layers
dtype: dtype,
initializer: initializer,
trainable: trainable,
getter: (name1, shape1, dtype1, initializer1, trainable1) =>
tf.compat.v1.get_variable(name1,
shape: new TensorShape(shape1),
dtype: dtype1,
initializer: initializer1,
trainable: trainable1)
getter: (args) =>
tf.compat.v1.get_variable(args.Name,
shape: args.Shape,
dtype: args.DType,
initializer: args.Initializer,
trainable: args.Trainable)
);

//if (init_graph != null)


+ 5
- 14
src/TensorFlowNET.Core/Training/Trackable.cs View File

@@ -27,16 +27,7 @@ namespace Tensorflow.Train
/// Restore-on-create for a variable be saved with this `Checkpointable`.
/// </summary>
/// <returns></returns>
protected virtual IVariableV1 _add_variable_with_custom_getter(string name,
int[] shape,
TF_DataType dtype = TF_DataType.TF_FLOAT,
IInitializer initializer = null,
Func<string, int[], TF_DataType, IInitializer, bool, IVariableV1> getter = null,
bool overwrite = false,
bool trainable = false,
bool use_resource = false,
VariableSynchronization synchronization = VariableSynchronization.Auto,
VariableAggregation aggregation = VariableAggregation.None)
protected virtual IVariableV1 _add_variable_with_custom_getter(VariableArgs args)
{
ops.init_scope();
#pragma warning disable CS0219 // Variable is assigned but its value is never used
@@ -50,15 +41,15 @@ namespace Tensorflow.Train
checkpoint_initializer = null;

IVariableV1 new_variable;
new_variable = getter(name, shape, dtype, initializer, trainable);
new_variable = args.Getter(args);

// If we set an initializer and the variable processed it, tracking will not
// assign again. It will add this variable to our dependencies, and if there
// is a non-trivial restoration queued, it will handle that. This also
// handles slot variables.
if (!overwrite || new_variable is RefVariable)
return _track_checkpointable(new_variable, name: name,
overwrite: overwrite);
if (!args.Overwrite || new_variable is RefVariable)
return _track_checkpointable(new_variable, name: args.Name,
overwrite: args.Overwrite);
else
return new_variable;
}


+ 26
- 0
src/TensorFlowNET.Core/Variables/VariableArgs.cs View File

@@ -0,0 +1,26 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class VariableArgs
{
public object InitialValue { get; set; }
public Func<VariableArgs, IVariableV1> Getter { get; set; }
public string Name { get; set; }
public TensorShape Shape { get; set; }
public TF_DataType DType { get; set; } = TF_DataType.DtInvalid;
public IInitializer Initializer { get; set; }
public bool Trainable { get; set; }
public bool ValidateShape { get; set; } = true;
public bool UseResource { get; set; } = true;
public bool Overwrite { get; set; }
public List<string> Collections { get; set; }
public string CachingDevice { get; set; } = "";
public VariableDef VariableDef { get; set; }
public string ImportScope { get; set; } = "";
public VariableSynchronization Synchronization { get; set; } = VariableSynchronization.Auto;
public VariableAggregation Aggregation { get; set; } = VariableAggregation.None;
}
}

+ 1
- 0
src/TensorFlowNET.Core/tensorflow.cs View File

@@ -62,6 +62,7 @@ namespace Tensorflow
public ResourceVariable Variable<T>(T data,
bool trainable = true,
bool validate_shape = true,
bool use_resource = true,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
int[] shape = null)


Loading…
Cancel
Save