From d4f1c349bc5c6a49c74fffec532d8ade8332755b Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 19 Sep 2020 07:10:27 -0500 Subject: [PATCH] fix _apply_dense for Optimizer. --- src/TensorFlowNET.Core/APIs/tf.layers.cs | 3 +- src/TensorFlowNET.Core/APIs/tf.nn.cs | 4 +- .../Framework/meta_graph.cs | 7 +- .../Functions/c_api.function.cs | 17 ++++- src/TensorFlowNET.Core/Gradients/math_grad.cs | 11 ++- src/TensorFlowNET.Core/Gradients/nn_grad.cs | 10 +-- src/TensorFlowNET.Core/Graphs/Graph.cs | 6 -- src/TensorFlowNET.Core/Graphs/c_api.graph.cs | 2 +- .../TensorLikeDataAdapterArgs.cs | 16 +++++ .../Keras/Engine/DataAdapters/DataHandler.cs | 2 + .../DataAdapters/TensorLikeDataAdapter.cs | 3 +- .../Keras/Engine/Layer.State.cs | 21 ++++++ src/TensorFlowNET.Core/Keras/Engine/Model.cs | 35 +++++++++- .../Keras/Layers/Embedding.cs | 4 ++ .../Keras/Layers/LayersApi.cs | 9 ++- .../Initializers/TruncatedNormal.cs | 6 +- .../Operations/Operation.cs | 31 ++++++++- .../Operations/array_ops.cs | 66 +++++++++++++----- src/TensorFlowNET.Core/Operations/nn_ops.cs | 2 +- .../Training/AdamOptimizer.cs | 12 ++-- src/TensorFlowNET.Core/Training/Optimizer.cs | 17 ++++- .../Training/Saving/BaseSaverBuilder.cs | 2 +- .../Saving/ResourceVariableSaveable.cs | 1 + .../Saving/saveable_object_util.py.cs | 6 +- .../Training/gen_training_ops.cs | 6 +- .../Variables/BaseResourceVariable.cs | 15 +++- .../Variables/IVariableV1.cs | 6 ++ .../Variables/ResourceVariable.Implicit.cs | 14 ++-- .../Variables/ResourceVariable.cs | 69 +++++++++++-------- src/TensorFlowNET.Core/ops.cs | 2 +- src/TensorFlowNET.Core/tensorflow.cs | 2 + 31 files changed, 299 insertions(+), 108 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorLikeDataAdapterArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/Engine/Layer.State.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs index 46772dc9..3485fbd5 100644 --- a/src/TensorFlowNET.Core/APIs/tf.layers.cs +++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs @@ -193,8 +193,7 @@ namespace Tensorflow Name = name }); - throw new NotImplementedException(""); - //return layer.apply(inputs).Item1; + return layer.Apply(inputs); } /// diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index 3ab2a5b4..c3e01278 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -66,8 +66,8 @@ namespace Tensorflow Tensor keep = null; if (keep_prob != null) keep = 1.0f - keep_prob; - - return nn_ops.dropout_v2(x, rate: rate.Value, noise_shape: noise_shape, seed: seed, name: name); + var rate_tensor = rate.HasValue ? tf.constant(rate.Value) : keep; + return nn_ops.dropout_v2(x, rate: rate_tensor, noise_shape: noise_shape, seed: seed, name: name); } /// diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.cs b/src/TensorFlowNET.Core/Framework/meta_graph.cs index 58add851..2716dfa2 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.cs @@ -150,7 +150,7 @@ namespace Tensorflow var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope: scope_to_prepend_to_names); var var_list = new Dictionary(); - // variables.ForEach(v => var_list[ops.strip_name_scope(v.Name, scope_to_prepend_to_names)] = v); + variables.ForEach(v => var_list[ops.strip_name_scope(v.Name, scope_to_prepend_to_names)] = v); return (var_list, imported_return_elements); } @@ -277,6 +277,11 @@ namespace Tensorflow var proto = x_ref_var.to_proto(export_scope); col_def.BytesList.Value.Add(proto.ToByteString()); } + else if(x is ResourceVariable x_res_var) + { + var proto = x_res_var.to_proto(export_scope); + col_def.BytesList.Value.Add(proto.ToByteString()); + } } break; case List collection_list: diff --git a/src/TensorFlowNET.Core/Functions/c_api.function.cs b/src/TensorFlowNET.Core/Functions/c_api.function.cs index 11ed7bdd..058fe7f2 100644 --- a/src/TensorFlowNET.Core/Functions/c_api.function.cs +++ b/src/TensorFlowNET.Core/Functions/c_api.function.cs @@ -31,8 +31,23 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static extern void TF_FunctionToFunctionDef(IntPtr func, IntPtr output_func_def, SafeStatusHandle status); + public static extern void TF_FunctionToFunctionDef(IntPtr func, SafeBufferHandle output_func_def, SafeStatusHandle status); + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_GraphToFunction(IntPtr fn_body, string fn_name, + bool append_hash_to_fn_name, + int num_opers, IntPtr[] opers, + int ninputs, TF_Output[] inputs, + int noutputs, TF_Output[] outputs, + IntPtr output_names, + IntPtr opts, + string description, + SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_FunctionName(IntPtr func); + [DllImport(TensorFlowLibName)] + public static extern void TF_GraphCopyFunction(IntPtr g, IntPtr func, IntPtr grad, SafeStatusHandle status); } } diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index 7563ed4d..7622a6ae 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -327,8 +327,9 @@ namespace Tensorflow.Gradients var output_shape = op.outputs[0]._shape_tuple(); Tensor result, factor_tensor; - if(input_shape != null && - output_shape != null) + if(tf.executing_eagerly() + && input_shape != null + && output_shape != null) { var input_size = np.prod(input_shape); var output_size = np.prod(output_shape); @@ -339,11 +340,7 @@ namespace Tensorflow.Gradients { var input_shape_tensor = array_ops.shape(op.inputs[0]); var output_shape_tensor = array_ops.shape(op.outputs[0]); - var factor = _safe_shape_div(math_ops.reduce_prod(input_shape_tensor), math_ops.reduce_prod(output_shape_tensor)); - throw new NotImplementedException(""); -#pragma warning disable CS0162 // Unreachable code detected - factor_tensor = null; -#pragma warning restore CS0162 // Unreachable code detected + factor_tensor = _safe_shape_div(math_ops.reduce_prod(input_shape_tensor), math_ops.reduce_prod(output_shape_tensor)); } result = math_ops.truediv(sum_grad, math_ops.cast(factor_tensor, sum_grad.dtype)); diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs index 6a2df6e9..e2564ff5 100644 --- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs @@ -128,10 +128,10 @@ namespace Tensorflow.Gradients [RegisterGradient("Conv2D")] public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads) { - var dilations = op.get_attr("dilations"); - var strides = op.get_attr("strides"); + var dilations = op.get_attr_list("dilations"); + var strides = op.get_attr_list("strides"); var padding = op.get_attr("padding"); - var explicit_paddings = op.get_attr("explicit_paddings"); + var explicit_paddings = op.get_attr_list("explicit_paddings"); var use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu"); var data_format = op.get_attr("data_format"); var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] }); @@ -287,8 +287,8 @@ namespace Tensorflow.Gradients op.inputs[0], op.outputs[0], grad, - op.get_attr("ksize") as int[], - op.get_attr("strides") as int[], + op.get_attr_list("ksize"), + op.get_attr_list("strides"), padding: op.get_attr("padding").ToString(), data_format: op.get_attr("data_format").ToString()) }; diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 70d88a91..35275d4e 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -293,12 +293,6 @@ namespace Tensorflow _create_op_helper(op, compute_device); - /*Console.Write($"create_op: {op_type} '{node_def.Name}'"); - Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}"); - Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}"); - Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}"); - Console.WriteLine();*/ - return op; } diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index 429c448d..471001bc 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -139,7 +139,7 @@ namespace Tensorflow /// TF_Status* [DllImport(TensorFlowLibName)] public static extern void TF_GraphToGraphDef(IntPtr graph, SafeBufferHandle output_graph_def, SafeStatusHandle status); - + /// /// Returns the number of dimensions of the Tensor referenced by `output` /// in `graph`. diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorLikeDataAdapterArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorLikeDataAdapterArgs.cs new file mode 100644 index 00000000..891af9d9 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorLikeDataAdapterArgs.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class TensorLikeDataAdapterArgs + { + public Tensor X { get; set; } + public Tensor Y { get; set; } + public int BatchSize { get; set; } + public int Steps { get; set; } + public int Epochs { get; set; } + public bool Shuffle { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/DataHandler.cs b/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/DataHandler.cs index de29d299..9a0351ea 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/DataHandler.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/DataHandler.cs @@ -27,7 +27,9 @@ namespace Tensorflow.Keras.Engine.DataAdapters public DataHandler(DataHandlerArgs args) { + this.args = args; + var adapter_cls = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs { }); } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs index 0311219a..74c4be93 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Keras.ArgsDefinition; using static Tensorflow.Binding; namespace Tensorflow.Keras.Engine.DataAdapters @@ -10,7 +11,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters /// public class TensorLikeDataAdapter : IDataAdapter { - public TensorLikeDataAdapter() + public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args) { tf.data.Dataset.range(5); } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.State.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.State.cs new file mode 100644 index 00000000..bb2036a5 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.State.cs @@ -0,0 +1,21 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Engine +{ + public partial class Layer + { + Dictionary trainable_state; + Dictionary _get_trainable_state() + { + trainable_state = new Dictionary(); + throw new NotImplementedException(""); + } + + void _set_trainable_state(Dictionary trainable_state) + { + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.cs index 8bba33fb..e4af7021 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.cs @@ -1,6 +1,7 @@ -using NumSharp; +using static Tensorflow.Binding; using System; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine.DataAdapters; using Tensorflow.Keras.Losses; using Tensorflow.Keras.Optimizers; @@ -21,6 +22,7 @@ namespace Tensorflow.Keras.Engine #pragma warning restore CS0108 // Member hides inherited member; missing new keyword string loss; IOptimizer optimizer; + IVariableV1 _steps_per_execution; public Model(ModelArgs args) : base(args) @@ -37,10 +39,25 @@ namespace Tensorflow.Keras.Engine break; } + int experimental_steps_per_execution = 1; + _configure_steps_per_execution(experimental_steps_per_execution); + + _reset_compile_cache(); + loss = lossName; _is_compiled = true; + } + + void _configure_steps_per_execution(int steps_per_execution) + { + _steps_per_execution = tf.Variable(steps_per_execution, + dtype: TF_DataType.TF_INT64, + aggregation: VariableAggregation.OnlyFirstReplica); + } + + void _reset_compile_cache() + { - // Prepare list of loss functions, same size of model outputs. } public void compile(string optimizerName, ILossFunc lossName) @@ -70,6 +87,20 @@ namespace Tensorflow.Keras.Engine int workers = 1, bool use_multiprocessing = false) { + var data_handler = new DataHandler(new DataHandlerArgs + { + X = x, + BatchSize = batch_size, + StepsPerEpoch = steps, + InitialEpoch = 0, + Epochs = 1, + MaxQueueSize = max_queue_size, + Workers = workers, + UseMultiprocessing = use_multiprocessing, + Model = this, + StepsPerExecution = _steps_per_execution + }); + throw new NotImplementedException(""); } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs index 9aa3747f..bbc9e66d 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using System.Linq; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using static Tensorflow.Binding; @@ -44,6 +45,9 @@ namespace Tensorflow.Keras.Layers if (args.InputShape == null) args.InputShape = args.InputLength; + if (args.BatchInputShape == null) + args.BatchInputShape = new int[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray(); + embeddings_initializer = embeddings_initializer ?? tf.random_uniform_initializer; SupportsMasking = mask_zero; } diff --git a/src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs index d1d876de..fc0b209f 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs @@ -34,10 +34,13 @@ namespace Tensorflow.Keras.Layers /// /// Turns positive integers (indexes) into dense vectors of fixed size. + /// This layer can only be used as the first layer in a model. + /// e.g. [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]] + /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding /// - /// - /// - /// + /// Size of the vocabulary, i.e. maximum integer index + 1. + /// Dimension of the dense embedding. + /// Initializer for the embeddings matrix (see keras.initializers). /// /// public Embedding Embedding(int input_dim, diff --git a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs index e656f7ea..048c11e7 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs @@ -36,9 +36,9 @@ namespace Tensorflow.Operations.Initializers public Tensor Apply(InitializerArgs args) { - if (args.DType == TF_DataType.DtInvalid) - args.DType = this.dtype; - return random_ops.truncated_normal(args.Shape, mean, stddev, dtype : dtype, seed: seed); + if (args.DType != TF_DataType.DtInvalid) + dtype = args.DType; + return random_ops.truncated_normal(args.Shape, mean, stddev, dtype: dtype, seed: seed); } } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 595c0ce8..db528e70 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -230,6 +230,35 @@ namespace Tensorflow public virtual T get_attr(string name) => (T)get_attr(name); + public virtual T[] get_attr_list(string name) + { + if (tf.executing_eagerly()) + return (T[])get_attr(name); + + AttrValue x = null; + + lock (Locks.ProcessWide) + { + using var buf = new Buffer(); + c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle); + tf.Status.Check(true); + + x = AttrValue.Parser.ParseFrom(buf.DangerousMemoryBlock.Stream()); + } + + string oneof_value = x.ValueCase.ToString(); + if (string.IsNullOrEmpty(oneof_value)) + return null; + + switch (typeof(T).Name) + { + case nameof(Int32): + return x.List.I.Select(x => (T)Convert.ChangeType(x, typeof(T))).ToArray(); + default: + return null; + } + } + public virtual object get_attr(string name) { AttrValue x = null; @@ -250,7 +279,7 @@ namespace Tensorflow if (oneof_value == "list") throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); - if (oneof_value == "type") + if (string.Equals("type", oneof_value, StringComparison.OrdinalIgnoreCase)) return x.Type; object result = x.GetType().GetProperty(oneof_value).GetValue(x); diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 7e508dce..7b0d6a94 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -85,26 +85,56 @@ namespace Tensorflow allow_broadcast: false); public static Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) - => tf_with(ops.name_scope(name, "zeros", shape), scope => + { + dtype = dtype.as_base_dtype(); + + if (tf.executing_eagerly()) { - dtype = dtype.as_base_dtype(); - name = scope; - var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); - Tensor zeros = null; - switch (dtype) + return tf_with(ops.name_scope(name, "zeros", shape), scope => { - case TF_DataType.TF_DOUBLE: - zeros = constant(0d); - break; - case TF_DataType.TF_FLOAT: - zeros = constant(0f); - break; - default: - zeros = constant(0); - break; - } - return fill(shape_tensor, zeros, name: name); - }); + name = scope; + var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); + Tensor zeros = null; + switch (dtype) + { + case TF_DataType.TF_DOUBLE: + zeros = constant(0d); + break; + case TF_DataType.TF_FLOAT: + zeros = constant(0f); + break; + default: + zeros = constant(0); + break; + } + return fill(shape_tensor, zeros, name: name); + }); + } + else + { + return tf_with(ops.name_scope(name, "zeros", shape), scope => + { + name = scope; + switch (dtype) + { + case TF_DataType.TF_BOOL: + return _constant_if_small(false, shape, dtype, name); + case TF_DataType.TF_DOUBLE: + return _constant_if_small(0.0D, shape, dtype, name); + case TF_DataType.TF_FLOAT: + return _constant_if_small(0.0F, shape, dtype, name); + case TF_DataType.TF_INT64: + return _constant_if_small(0l, shape, dtype, name); + case TF_DataType.TF_INT32: + return _constant_if_small(0, shape, dtype, name); + case TF_DataType.TF_INT8: + return _constant_if_small(0, shape, dtype, name); + default: + throw new TypeError("can't find type for zeros"); + } + }); + } + } public static Tensor boolean_mask(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) { diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs index ce3875cc..4c30c34e 100644 --- a/src/TensorFlowNET.Core/Operations/nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs @@ -68,7 +68,7 @@ namespace Tensorflow /// /// /// - public static Tensor dropout_v2(Tensor x, float rate, Tensor noise_shape = null, int? seed = null, string name = null) + public static Tensor dropout_v2(Tensor x, Tensor rate, Tensor noise_shape = null, int? seed = null, string name = null) { return tf_with(ops.name_scope(name, "dropout", x), scope => { diff --git a/src/TensorFlowNET.Core/Training/AdamOptimizer.cs b/src/TensorFlowNET.Core/Training/AdamOptimizer.cs index 6f62fd27..47d4331c 100644 --- a/src/TensorFlowNET.Core/Training/AdamOptimizer.cs +++ b/src/TensorFlowNET.Core/Training/AdamOptimizer.cs @@ -60,17 +60,17 @@ namespace Tensorflow.Train }); } - public override Operation _apply_dense(Tensor grad, RefVariable var) + public override Operation _apply_dense(Tensor grad, ResourceVariable var) { var m = get_slot(var, "m"); var v = get_slot(var, "v"); var (beta1_power, beta2_power) = _get_beta_accumulators(); return gen_training_ops.apply_adam( - var, - m, - v, - math_ops.cast(beta1_power, var.dtype.as_base_dtype()), - math_ops.cast(beta2_power, var.dtype.as_base_dtype()), + var.Handle, + m.Handle, + v.Handle, + math_ops.cast(beta1_power.Handle, var.dtype.as_base_dtype()), + math_ops.cast(beta2_power.Handle, var.dtype.as_base_dtype()), math_ops.cast(_lr_t, var.dtype.as_base_dtype()), math_ops.cast(_beta1_t, var.dtype.as_base_dtype()), math_ops.cast(_beta2_t, var.dtype.as_base_dtype()), diff --git a/src/TensorFlowNET.Core/Training/Optimizer.cs b/src/TensorFlowNET.Core/Training/Optimizer.cs index 9019c146..c9c1673c 100644 --- a/src/TensorFlowNET.Core/Training/Optimizer.cs +++ b/src/TensorFlowNET.Core/Training/Optimizer.cs @@ -278,8 +278,16 @@ namespace Tensorflow public virtual Operation _apply_dense(Tensor grad, ResourceVariable var) { - var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype()); - return gen_training_ops.resource_apply_gradient_descent(var.Handle, alpha, grad, use_locking: _use_locking).op; + if (tf.executing_eagerly()) + { + var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype()); + return gen_training_ops.resource_apply_gradient_descent(var, alpha, grad, use_locking: _use_locking).op; + } + else + { + var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype()); + return gen_training_ops.apply_gradient_descent(var, alpha, grad, use_locking: _use_locking).op; + } } public virtual Operation _apply_dense(Tensor grad, RefVariable var) @@ -314,6 +322,11 @@ namespace Tensorflow return _apply_sparse(gradient_no_duplicate_indices, var); } + public virtual Operation _apply_sparse(IndexedSlices grad, ResourceVariable var) + { + throw new NotImplementedException("_apply_sparse"); + } + public virtual Operation _apply_sparse(IndexedSlices grad, RefVariable var) { throw new NotImplementedException("_apply_sparse"); diff --git a/src/TensorFlowNET.Core/Training/Saving/BaseSaverBuilder.cs b/src/TensorFlowNET.Core/Training/Saving/BaseSaverBuilder.cs index 8bc5811c..9bb763e3 100644 --- a/src/TensorFlowNET.Core/Training/Saving/BaseSaverBuilder.cs +++ b/src/TensorFlowNET.Core/Training/Saving/BaseSaverBuilder.cs @@ -224,7 +224,7 @@ namespace Tensorflow var saveable_tensors = all_tensors.Skip(idx).Take(saveable.specs.Length); idx += saveable.specs.Length; var restored = saveable.restore(saveable_tensors.ToArray(), shapes == null ? null : shapes.ToArray()); - assign_ops.Add(restored as ITensorOrOperation); + assign_ops.Add(restored); } return control_flow_ops.group(assign_ops.ToArray(), name: name); diff --git a/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs b/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs index 415671c2..d71ac4b9 100644 --- a/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs +++ b/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ******************************************************************************/ +using static Tensorflow.Binding; namespace Tensorflow { diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index 19f04650..7c2d3330 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -67,9 +67,7 @@ namespace Tensorflow { ops.init_scope(); var variable = ops.internal_convert_to_tensor(op, as_ref: true); - if (variable.op.type == "Variable" || - variable.op.type == "VariableV2" || - variable.op.type == "AutoReloadVariable") + if (variable.dtype.is_ref_dtype()) yield return new ReferenceVariableSaveable(variable, "", name); else yield return new ResourceVariableSaveable(variable, "", name); @@ -102,7 +100,7 @@ namespace Tensorflow if (convert_variable_to_tensor) { - if (var is ResourceVariable) + if (!var.dtype.is_ref_dtype()) tensor = var.GraphElement; else tensor = ops.internal_convert_to_tensor(var, as_ref: true); diff --git a/src/TensorFlowNET.Core/Training/gen_training_ops.cs b/src/TensorFlowNET.Core/Training/gen_training_ops.cs index 36eca7d6..c141d59e 100644 --- a/src/TensorFlowNET.Core/Training/gen_training_ops.cs +++ b/src/TensorFlowNET.Core/Training/gen_training_ops.cs @@ -41,7 +41,7 @@ namespace Tensorflow throw new NotImplementedException(""); } - public static Tensor apply_adam(IVariableV1 var, IVariableV1 m, IVariableV1 v, Tensor beta1_power, Tensor beta2_power, + public static Tensor apply_adam(Tensor var, Tensor m, Tensor v, Tensor beta1_power, Tensor beta2_power, Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad, bool use_locking = false, bool use_nesterov = false, string name = null) { @@ -64,7 +64,7 @@ namespace Tensorflow return _op.outputs[0]; } - public static Tensor apply_gradient_descent(RefVariable var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null) + public static Tensor apply_gradient_descent(IVariableV1 var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null) { var _op = tf.OpDefLib._apply_op_helper("ApplyGradientDescent", name, new { @@ -82,7 +82,7 @@ namespace Tensorflow if (tf.executing_eagerly()) { var result = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ResourceApplyGradientDescent", name, + "ResourceApplyGradientDescent", name, null, var, alpha, delta, "use_locking", use_locking); diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 9b179b4e..18b93ec3 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -28,6 +28,8 @@ namespace Tensorflow protected Tensor _initial_value; public Tensor initial_value => _initial_value; + public Operation initializer => initializer_op; + protected Tensor _parent_op; public Tensor parent_op => _parent_op; @@ -73,6 +75,14 @@ namespace Tensorflow public ITensorOrOperation assign(T value, bool use_locking = false, string name = null, bool read_value = true) { + if(value.GetType() == typeof(Tensor)) + { + var assign = gen_state_ops.assign(handle, value, use_locking: use_locking, name: name); + if (read_value) + return assign; + return assign.op; + } + var value_tensor = ops.convert_to_tensor(value, dtype: dtype); var assign_op = gen_resource_variable_ops.assign_variable_op( handle, value_tensor, name: name); @@ -82,7 +92,7 @@ namespace Tensorflow return assign_op; } - public Tensor value() => _read_variable_op(); + public Tensor value() => tf.executing_eagerly() ? _read_variable_op() : GraphElement; protected Tensor _read_variable_op() { @@ -149,6 +159,7 @@ namespace Tensorflow { } - public Tensor AsTensor() => read_value(); + public Tensor AsTensor() + => tf.executing_eagerly() ? read_value() : GraphElement; } } diff --git a/src/TensorFlowNET.Core/Variables/IVariableV1.cs b/src/TensorFlowNET.Core/Variables/IVariableV1.cs index 9178d6ad..cd76b092 100644 --- a/src/TensorFlowNET.Core/Variables/IVariableV1.cs +++ b/src/TensorFlowNET.Core/Variables/IVariableV1.cs @@ -33,10 +33,16 @@ namespace Tensorflow { public string UniqueId { get; } public string Name { get; } + /// + /// Handle is ref type + /// public Tensor Handle { get; } public string Device { get; } public Operation Initializer { get; } public Operation Op { get; } + /// + /// GraphElement is a copy of Handle + /// public Tensor GraphElement { get; } public Graph Graph { get; } public TF_DataType dtype { get; } diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.Implicit.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.Implicit.cs index 7f91340b..d8a743dc 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.Implicit.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.Implicit.cs @@ -1,5 +1,6 @@ using System; using Tensorflow.Eager; +using static Tensorflow.Binding; namespace Tensorflow { @@ -21,11 +22,6 @@ namespace Tensorflow public static implicit operator EagerTensor(ResourceVariable var) => var._dense_var_to_tensor() as EagerTensor; - public static implicit operator RefVariable(ResourceVariable var) - { - return null; - } - public static implicit operator IntPtr(ResourceVariable var) => var._handle; @@ -35,5 +31,13 @@ namespace Tensorflow { return value(); } + + public Tensor _TensorConversionFunction(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) + { + if (as_ref) + return handle; + else + return tf.executing_eagerly() ? AsTensor() : value(); + } } } diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index ea2ee42a..d42eb3dd 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -49,6 +49,7 @@ namespace Tensorflow VariableDef variable_def = null, TF_DataType dtype = TF_DataType.DtInvalid, string import_scope = "", + VariableAggregation aggregation = VariableAggregation.None, TensorShape shape = null) { if (variable_def != null) @@ -65,6 +66,7 @@ namespace Tensorflow caching_device: caching_device, name: name, dtype: dtype, + aggregation: aggregation, shape: shape); } } @@ -75,6 +77,7 @@ namespace Tensorflow string caching_device = "", string name = null, TF_DataType dtype = TF_DataType.DtInvalid, + VariableAggregation aggregation = VariableAggregation.None, TensorShape shape = null) { var init_from_fn = initial_value.GetType().Name == "Func`1" || @@ -114,55 +117,43 @@ namespace Tensorflow if (initial_value.GetType().GetInterface("IInitializer") != null) initial_value = ops.convert_to_tensor((initial_value as IInitializer).Apply(new InitializerArgs(shape, dtype: dtype))); else - initial_value = ops.convert_to_tensor(init_from_fn ? (initial_value as Func)() : initial_value, + { + var value = init_from_fn ? (initial_value as Func)() : initial_value; + initial_value = ops.convert_to_tensor(value, name: "initial_value", dtype: dtype); + } }); _shape = shape ?? (initial_value as Tensor).TensorShape; _initial_value = initial_value as Tensor; - handle = resource_variable_ops.eager_safe_variable_handle( - initial_value: _initial_value, - shape: _shape, - shared_name: shared_name, - name: name, - graph_mode: _in_graph_mode); + - _dtype = _initial_value.dtype.as_base_dtype(); if (_in_graph_mode) { - tf_with(ops.name_scope("IsInitialized"), delegate - { - is_initialized_op = gen_resource_variable_ops.var_is_initialized_op(handle); - }); - - if(initial_value != null) - { - tf_with(ops.name_scope("Assign"), scope1 => - { - string n = scope1; - var _initial_value2 = variables._try_guard_against_uninitialized_dependencies(name, _initial_value); - initializer_op = gen_resource_variable_ops.assign_variable_op(handle, _initial_value2, name: n); - }); - } + handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); + initializer_op = gen_state_ops.assign(handle, _initial_value, true).op; - // Manually assign reads to the handle's device to avoid log - // messages. - tf_with(ops.name_scope("Read"), delegate - { - var value = gen_resource_variable_ops.read_variable_op(handle, _dtype); - // _maybe_set_handle_data(dtype, handle, value); - _graph_element = value; - }); + ops.colocate_with(initializer_op); + _graph_element = gen_array_ops.identity(handle, name = "read"); ops.add_to_collections(collections, this); + _dtype = handle.dtype; } else { + handle = resource_variable_ops.eager_safe_variable_handle( + initial_value: _initial_value, + shape: _shape, + shared_name: shared_name, + name: name, + graph_mode: _in_graph_mode); + gen_resource_variable_ops.assign_variable_op(handle, _initial_value); is_initialized_op = null; initializer_op = null; _graph_element = null; + _dtype = _initial_value.dtype.as_base_dtype(); initial_value = _in_graph_mode ? initial_value : null; } @@ -237,5 +228,23 @@ namespace Tensorflow return array_ops.identity(value); }); } + + public VariableDef to_proto(string export_scope) + { + if (string.IsNullOrEmpty(export_scope) || Handle.name.StartsWith(export_scope)) + { + var var_def = new VariableDef(); + var_def.VariableName = ops.strip_name_scope(Handle.name, export_scope); + if (_initial_value != null) + var_def.InitialValueName = ops.strip_name_scope(_initial_value.name, export_scope); + var_def.Trainable = _trainable; + var_def.InitializerName = ops.strip_name_scope(initializer.name, export_scope); + var_def.SnapshotName = ops.strip_name_scope(_graph_element.name, export_scope); + + return var_def; + } + + throw new NotImplementedException("to_proto RefVariable"); + } } } diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 513e8d07..f42e49bd 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -467,7 +467,7 @@ namespace Tensorflow case RefVariable varVal: return varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref); case ResourceVariable varVal: - return varVal.value(); + return varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref); case TensorShape ts: return constant_op.constant(ts.dims, dtype: dtype, name: name); case int[] dims: diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index 368c66ba..94c9f49f 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -70,12 +70,14 @@ namespace Tensorflow bool use_resource = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid, + VariableAggregation aggregation = VariableAggregation.None, int[] shape = null) => new ResourceVariable(data, trainable: trainable, validate_shape: validate_shape, name: name, dtype: dtype, + aggregation: aggregation, shape: shape); public Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null)