From 9420ba3243604722c1920ebe7664bb4ca78562c0 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Sun, 16 Apr 2023 00:58:23 +0800 Subject: [PATCH] Fix the error of loaded function model backward. --- .../Contexts/FunctionCallOptions.cs | 1 + .../Eager/EagerRunner.MustRecordGradient.cs | 32 +- .../Eager/EagerRunner.RecordGradient.cs | 19 +- .../Eager/EagerRunner.TFE_TapeGradient.cs | 179 +++++++++-- .../EagerRunner.TapeSetRecordBackprop.cs | 7 +- .../EagerRunner.TapeSetRecordOperation.cs | 16 +- src/TensorFlowNET.Core/Eager/IEagerRunner.cs | 9 +- .../Functions/ConcreteFunction.cs | 25 +- .../Functions/EagerDefinedFunction.cs | 19 +- .../FirstOrderTapeGradientFunctions.cs | 9 +- src/TensorFlowNET.Core/Functions/Function.cs | 8 +- .../Functions/TapeGradientFunctions.cs | 157 ++++++---- .../Functions/TracingCompiler.cs | 12 +- .../Functions/monomorphic_function.cs | 2 +- .../Gradients/BackpropInitialState.cs | 4 +- .../Gradients/GradientTape.cs | 35 ++- src/TensorFlowNET.Core/Gradients/ITape.cs | 23 +- .../Gradients/OpTapeEntry.cs | 4 +- .../Gradients/Tape.ComputeGradient.cs | 282 ++++++++++-------- .../Gradients/Tape.PrepareBackprop.cs | 63 ++-- .../Gradients/Tape.RecordOperation.cs | 31 +- src/TensorFlowNET.Core/Gradients/Tape.cs | 20 +- .../Gradients/TapeTensor.cs | 54 +++- .../Gradients/TensorTape.cs | 2 +- .../Gradients/gradients_util.cs | 27 +- src/TensorFlowNET.Core/Graphs/FuncGraph.cs | 10 + .../Keras/Engine/IOptimizer.cs | 4 +- .../Operations/c_api.ops.cs | 8 +- .../Operations/functional_ops.cs | 4 +- .../Operations/gen_array_ops.cs | 47 ++- .../Operations/handle_data_util.cs | 6 +- .../Operations/resource_variable_ops.cs | 14 + .../Tensorflow.Binding.csproj | 2 +- .../Tensors/Tensor.Creation.cs | 6 +- .../SavedModel/function_deserialization.cs | 1 + src/TensorFlowNET.Core/Util/UnorderedMap.cs | 13 + .../Variables/BaseResourceVariable.cs | 5 + src/TensorFlowNET.Core/ops.cs | 8 +- src/TensorFlowNET.Keras/Engine/Layer.cs | 22 +- src/TensorFlowNET.Keras/Engine/Model.Train.cs | 6 +- .../Optimizers/OptimizerV2.cs | 8 +- .../Saving/KerasObjectLoader.cs | 25 ++ .../Saving/SavedModel/RevivedLayer.cs | 22 +- .../SavedModel/serialized_attributes.cs | 2 +- .../SaveModel/SequentialModelLoad.cs | 26 +- 45 files changed, 870 insertions(+), 409 deletions(-) diff --git a/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs b/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs index 2fcf9dce..71312d11 100644 --- a/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs +++ b/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Text; using Google.Protobuf; +using Protobuf.Text; using static Tensorflow.Binding; namespace Tensorflow.Contexts diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs index c4bce84f..33382703 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs @@ -12,18 +12,36 @@ namespace Tensorflow.Eager return HasGradientTape(); } - private bool ShouldRecord(Tensor[] inputs) + public int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors) { - bool should_record = false; - foreach (var tape in tf.GetTapeSet()) + var tape_set = tf.GetTapeSet(); + var input_ids = MakeTensorIDList(tensors); + var input_dtypes = MakeTensorDtypeList(tensors); + bool some_tape_watching = false; + if (tape_set is not null && tape_set.Count > 0) { - if 
(tape.ShouldRecord(inputs)) + foreach (var tape in tape_set) { - should_record = true; - break; + if (tape.ShouldRecord(input_ids, input_dtypes)) + { + if (tape.Persistent || some_tape_watching) + { + return gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER; + } + some_tape_watching = true; + } } } - return should_record; + // skip the forward_accumulators. + + if (some_tape_watching) + { + return gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER; + } + else + { + return gradients_util.POSSIBLE_GRADIENT_TYPES_NONE; + } } } }
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs index cfcea55a..59d5fd03 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs @@ -13,7 +13,17 @@ namespace Tensorflow.Eager Tensor[] results, BackwardFunction backwardFunction = null) { - bool should_record = ShouldRecord(inputs); + var input_ids = MakeTensorIDList(inputs); + var input_dtypes = MakeTensorDtypeList(inputs); + bool should_record = false; + foreach (var tape in tf.GetTapeSet()) + { + if (tape.ShouldRecord(input_ids, input_dtypes)) + { + should_record = true; + break; + } + } if (!should_record) { @@ -59,7 +69,7 @@ namespace Tensorflow.Eager op_inputs = inputs;*/ backwardFunction = backwardFunction ?? GetGradientFunction(op_name, inputs, attrs, results); - TapeSetRecordOperation(op_name, inputs, results, backwardFunction); + TapeSetRecordOperation(op_name, inputs, results, input_ids, input_dtypes, backwardFunction); return true; } @@ -129,10 +139,5 @@ namespace Tensorflow.Eager { return HasGradientTape(); } - - TF_DataType[] MakeTensorDtypeList(Tensor[] tensors) - { - return tensors.Select(x => x.dtype).ToArray(); - } } }
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs index c96d09e5..1f7b3ae6 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs @@ -1,6 +1,8 @@ -using System; +using OneOf.Types; +using System; using Tensorflow.Gradients; using Tensorflow.Util; +using static Tensorflow.Binding; namespace Tensorflow.Eager { @@ -9,40 +11,183 @@ namespace Tensorflow.Eager /// public partial class EagerRunner { + /// <summary> + /// </summary> + /// <param name="tape"></param> + /// <param name="target"></param> + /// <param name="sources"></param> + /// <param name="output_gradients"></param> + /// <param name="sources_raw"></param> + /// <param name="unconnected_gradients">determines the value returned if the target and + /// sources are unconnected. When 'none' the value returned is None, whereas when + /// 'zero' a zero tensor in the same shape as the sources is returned.</param> + /// <returns></returns> public Tensor[] TFE_TapeGradient(ITape tape, Tensor[] target, Tensor[] sources, - Tensor[] output_gradients) + List<Tensor> output_gradients, + Tensor[] sources_raw, + string unconnected_gradients) { - var target_vec = target; - var sources_vec = sources; - var sources_set = sources_vec; + if (!tape.Persistent) + { + var tape_set = tf.GetTapeSet(); + if (tape_set.Contains(tape)) + { + throw new RuntimeError("gradient() cannot be invoked within the " + + "GradientTape context (i.e., while operations are being " + + "recorded). 
Either move the call to gradient() to be " + + "outside the 'with tf.GradientTape' block, or " + + "use a persistent tape: " + + "'with tf.GradientTape(persistent=true)'"); + } + } + + var target_vec = MakeTensorIDList(target); + var sources_vec = MakeTensorIDList(sources); + HashSet<long> sources_set = new HashSet<long>(sources_vec); + var source_tensors_that_are_targets = new UnorderedMap<long, TapeTensor>(); + + int len = target.Length; + for(int i = 0; i < len; i++) + { + var target_id = target_vec[i]; + if (sources_set.Contains(target_id)) + { + var tensor = target[i]; + source_tensors_that_are_targets[target_id] = TapeTensorFromTensor(tensor); + } + } + + List<Tensor> outgrad_vec = new(); + if(output_gradients is not null) + { + outgrad_vec = output_gradients.ToList(); + } + var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false); - var seq_array = target; - var source_tensors_that_are_targets = new UnorderedMap<Tensor, TapeTensor>(); - for (int i = 0; i < target.Length; ++i) + bool unconnected_gradients_zero = unconnected_gradients == "zero"; + Tensor[] sources_obj = null; + if (unconnected_gradients_zero) { - source_tensors_that_are_targets.Add(target_vec[i], new TapeTensor(seq_array[i])); + sources_obj = MakeTensorList(sources_raw); } - if (output_gradients != null) + if (result.Length > 0) { - throw new NotImplementedException(""); + for(int i = 0; i < result.Length; i++) + { + if (result[i] is null && unconnected_gradients_zero) + { + var dtype = sources_obj[i].dtype; + result[i] = new TapeTensor(sources_vec[i], dtype, sources_obj[i]).ZerosLike(); + } + } } - else + return result; + } + + Tensor[] MakeTensorList(IEnumerable<Tensor> tensors) + { + return tensors.ToArray(); + } + + long[] MakeTensorIDList(Tensor[] tensors) + { + int len = tensors.Length; + long[] ids = new long[len]; + for(int i = 0; i < len; i++) + { + var tensor = tensors[i]; + ids[i] = tensor.Id; + } + return ids; + } + + TF_DataType[] MakeTensorDtypeList(Tensor[] tensors) + { + int len = tensors.Length; + TF_DataType[] dtypes = new TF_DataType[len]; + for (int i = 0; i < len; i++) { - output_gradients = new Tensor[0]; + var tensor = tensors[i]; + dtypes[i] = tensor.dtype; } + return dtypes; + } - var outgrad_vec = MakeTensorList(output_gradients); + TapeTensor TapeTensorFromTensor(Tensor tensor) + { + long id = tensor.Id; + var dtype = tensor.dtype; + if (tensor is EagerTensor) + { + var handle = tensor.EagerTensorHandle; + if (DTypeNeedsHandleData(dtype)) + { + return new TapeTensor(id, c_api.TFE_TensorHandleDataType(handle), tensor); + } + + Status status = new(); + int num_dims = c_api.TFE_TensorHandleNumDims(handle, status); + long[] dims = new long[num_dims]; + for(int i = 0; i < num_dims; i++) + { + dims[i] = c_api.TFE_TensorHandleDim(handle, i, status); + } + Shape tensor_shape = new(dims); + + if(status.Code != TF_Code.TF_OK) + { + return new TapeTensor(id, TF_DataType.DtInvalid, Shape.Null); + } + else + { + return new TapeTensor(id, dtype, tensor_shape); + } + } + var shape_tuple = tensor.shape.dims; + if(ListContainNone(shape_tuple) || DTypeNeedsHandleData(dtype)) + { + return new TapeTensor(id, dtype, tensor); + } + long[] l = new long[shape_tuple.Length]; + for(int i = 0; i < shape_tuple.Length; i++) + { + if (shape_tuple[i] < 0) + { + l[i] = 0; + } + else + { + l[i] = shape_tuple[i]; + } + } + return new TapeTensor(id, dtype, new Shape(l)); + } - return tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec); + bool DTypeNeedsHandleData(TF_DataType dtype) + { + return 
dtype == dtypes.variant || dtype == dtypes.resource; } - Tensor[] MakeTensorList(Tensor[] tensors) + bool ListContainNone(long[] list) { - return tensors; + int len = list.Length; + if(len == 0) + { + return true; + } + for(int i = 0; i < len; i++) + { + if (list[i] == -1) + { + return true; + } + } + return false; } } }
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs index e8751aed..9bcc8fe2 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs @@ -7,8 +7,9 @@ namespace Tensorflow.Eager public partial class EagerRunner { void TapeSetRecordBackprop(string op_type, - Tensor[] input_tensors, - TapeTensor[] output_tensors, + TapeTensor[] output_info, + long[] input_ids, + TF_DataType[] input_dtypes, BackwardFunction backward_function) { if (!CouldBackprop()) @@ -18,7 +19,7 @@ namespace Tensorflow.Eager foreach (var tape in tf.GetTapeSet()) { - tape.RecordOperation(op_type, input_tensors, output_tensors, backward_function); + tape.RecordOperation(op_type, output_info, input_ids, input_dtypes, backward_function); } } }
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs index 42e1cff9..3987e7a3 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs @@ -10,18 +10,28 @@ namespace Tensorflow.Eager public bool TapeSetRecordOperation(string op_type, Tensor[] input_tensors, Tensor[] output_tensors, + long[] input_ids, + TF_DataType[] input_dtypes, BackwardFunction backward_function) { - var output_info = output_tensors.Select(x => new TapeTensor(x)).ToArray(); - + var output_info = output_tensors.Select(t => TapeTensorFromTensor(t)).ToArray(); if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info, backward_function)) return false; - TapeSetRecordBackprop(op_type, input_tensors, output_info, + TapeSetRecordBackprop(op_type, output_info, input_ids, input_dtypes, backward_function); return true; } + + public void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors, + Tensor[] input_tensors, BackwardFunction backward_function) + { + var input_ids = MakeTensorIDList(input_tensors); + var input_dtypes = MakeTensorDtypeList(input_tensors); + TapeSetRecordOperation(op_type, input_tensors, output_tensors, input_ids, input_dtypes, + backward_function); + } } }
diff --git a/src/TensorFlowNET.Core/Eager/IEagerRunner.cs b/src/TensorFlowNET.Core/Eager/IEagerRunner.cs index 7baf4cd7..21a33669 100644 --- a/src/TensorFlowNET.Core/Eager/IEagerRunner.cs +++ b/src/TensorFlowNET.Core/Eager/IEagerRunner.cs @@ -29,7 +29,14 @@ namespace Tensorflow.Eager Tensor[] TFE_TapeGradient(ITape tape, Tensor[] target, Tensor[] sources, - Tensor[] output_gradients); + List<Tensor> output_gradients, + Tensor[] sources_raw, + string unconnected_gradients); + + void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors, + Tensor[] input_tensors, BackwardFunction backward_function); + + int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors); bool RecordGradient(string op_name, Tensor[] inputs, diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index fbebd4d6..5c2d3a8d 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ 
b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -18,12 +18,13 @@ namespace Tensorflow.Functions public class ConcreteFunction: Trackable { protected IEnumerable _captured_inputs; - internal FuncGraph func_graph; protected DelayedRewriteGradientFunctions _delayed_rewrite_functions; protected Dictionary _attrs; protected FunctionSpec _function_spec; protected FunctionSpec _pre_initialized_function_spec = null; protected EagerDefinedFunction _inference_function; + protected Dictionary<string, TapeGradientFunctions> _tape_functions_cache = new(); + internal FuncGraph func_graph; internal ForwardBackwardCall forward_backward; public Tensor[] Inputs => func_graph.Inputs; public Tensor[] CapturedInputs => func_graph.external_captures; @@ -156,6 +157,17 @@ namespace Tensorflow.Functions { var executing_eagerly = tf.Context.executing_eagerly(); var default_graph = ops.get_default_graph(); + // TODO(Rinne): deal with `default_graph.building_function` + + var tempvv = func_graph.Variables; + if(tf.GetTapeSet().Count > 0 || default_graph is FuncGraph) + { + foreach(var v in this.func_graph.Variables) + { + resource_variable_ops.variable_accessed(v); + } + } + var tensor_inputs = new Tensors(); foreach (var (i, arg) in enumerate(args)) { @@ -223,11 +235,16 @@ namespace Tensorflow.Functions { input_tangents = new TangentInfo(); } - if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER || tf.Runner.MustRecordGradient()) + if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER) { if(input_tangents.Indices is not null || executing_eagerly) { - var functions = new FirstOrderTapeGradientFunctions(func_graph, false); + string cache_key = "first_order"; + if(!_tape_functions_cache.TryGetValue(cache_key, out var functions)) + { + functions = new FirstOrderTapeGradientFunctions(func_graph, false); + _tape_functions_cache[cache_key] = functions; + } return new ForwardBackwardCall(functions, args, tape_watching: true); } else @@ -241,7 +258,7 @@ namespace Tensorflow.Functions } // TODO(Rinne): add arg "input_tangents" for ForwardBackwardCall. - return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: tf.Runner.MustRecordGradient()); + return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false); } internal void set_variables(IEnumerable variables) diff --git a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs index c2f8e016..cc38683d 100644 --- a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs +++ b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs @@ -124,17 +124,16 @@ namespace Tensorflow.Functions // TODO(Rinne): Add arg `CancellationManager`. // TODO(Rinne): Check the arg length. var function_call_options = tf.Context.FunctionCallOptions; - string config; - if (function_call_options.config_proto_serialized().Length == 0) - { - config = function_utils.get_disabled_rewriter_config().ToString(); - } - else - { - config = function_call_options.config_proto_serialized().ToString(); - } + string config = ""; // TODO(Rinne): revise it. The following code should work, but it does not, for unclear reasons. - config = ""; // TODO(Rinne): revise it. + //if (function_call_options.config_proto_serialized().Length == 0) + //{ + // config = function_utils.get_disabled_rewriter_config().ToStringUtf8(); + //} + //else + //{ + // config = function_call_options.config_proto_serialized().ToStringUtf8(); + //} string executor_type = function_call_options.ExecutorType ?? 
""; var executing_eagerly = tf.Context.executing_eagerly(); diff --git a/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs index c0e69dba..bfb0defc 100644 --- a/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs +++ b/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs @@ -14,12 +14,11 @@ namespace Tensorflow.Functions } - public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args) + public override (EagerDefinedFunction, FuncGraph, ConcreteFunction, List, int) + ForwardAndBackwardFunctions(Tensors inference_args) { - var outputs = _func_graph.Outputs; - (_forward_function, _forward_graph, _backward_function, _forwardprop_output_indices, _num_forwardprop_outputs) - = BuildFunctionsForOutputs(outputs, inference_args); - return _forward_function; + var outputs = _func_graph.Outputs.Take(_num_inference_outputs).ToArray(); + return BuildFunctionsForOutputs(outputs, inference_args); } } } diff --git a/src/TensorFlowNET.Core/Functions/Function.cs b/src/TensorFlowNET.Core/Functions/Function.cs index a53df14c..ea1b3eec 100644 --- a/src/TensorFlowNET.Core/Functions/Function.cs +++ b/src/TensorFlowNET.Core/Functions/Function.cs @@ -14,7 +14,6 @@ namespace Tensorflow protected ConcreteFunction _concrete_variable_creation_fn; protected bool _autograph; protected TracingCompiler _variable_creation_fn; - protected bool _has_initialized; public string Name { get; set; } public Function(Func csharp_function, string name, bool auto_graph = true) @@ -22,7 +21,6 @@ namespace Tensorflow _csharp_function = csharp_function; Name = name; _autograph = auto_graph; - _has_initialized = false; } public virtual Tensors Apply(Tensors inputs) @@ -38,10 +36,11 @@ namespace Tensorflow protected virtual Tensors _call(Tensors inputs) { - if (!_has_initialized) + if(_variable_creation_fn is not null) { - _initialize(inputs); + return _variable_creation_fn.Apply(inputs); } + _initialize(inputs); return _concrete_variable_creation_fn.CallFlat(inputs, _concrete_variable_creation_fn.CapturedInputs); @@ -63,7 +62,6 @@ namespace Tensorflow _variable_creation_fn = _compiler(_csharp_function); _variable_creation_fn._name = this.Name; _concrete_variable_creation_fn = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args); - _has_initialized = true; } } } diff --git a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs index 638aeaf1..3895226e 100644 --- a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs +++ b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs @@ -24,23 +24,40 @@ namespace Tensorflow.Functions protected string _INFERENCE_PREFIX = "__inference_"; protected FuncGraph _func_graph; - protected EagerDefinedFunction _forward_function; + protected EagerDefinedFunction _forward; protected FuncGraph _forward_graph; + protected List _forwardprop_input_indices; protected List _forwardprop_output_indices; protected int _num_forwardprop_outputs; - protected ConcreteFunction _backward_function; + protected int _num_inference_outputs; + protected int _num_outputs; + protected int _num_trainable_inference_outputs; + protected ConcreteFunction _backward; BackwardFunction _backward_function_wrapper; public TapeGradientFunctions(FuncGraph func_graph, bool need_gradients_for_jvps) { _func_graph = func_graph; + _forward_graph = null; + _forward = null; + _backward = null; + _num_outputs 
= func_graph.Outputs.Length; + _forwardprop_output_indices = null; + _num_forwardprop_outputs = 0; + _num_inference_outputs = func_graph.Outputs.Length; + _num_trainable_inference_outputs = func_graph.Outputs.Where(t => backprop_util.IsTrainable(t)).Count(); } public virtual EagerDefinedFunction Forward(Tensors inference_args, Tensors input_tangents = null) { // TODO(Rinne): add input_tangents arg. - return ForwardAndBackwardFunctions(inference_args); + if(_forward is null) + { + (_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs) + = ForwardAndBackwardFunctions(inference_args); + } + return _forward; } /// @@ -51,9 +68,13 @@ namespace Tensorflow.Functions public virtual void Record(Tensors flat_outputs, Tensors inference_args) { // TODO(Rinne): add arg `input_tagents`. - var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward_function, flat_outputs); - tf.Runner.RecordGradient(_forward_function.Name, inference_args, new object[0], to_record, - getBackwardFunction: backward_function); + var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs); + if(_forwardprop_output_indices is not null && _forwardprop_output_indices.Count > 0) + { + // TODO(Rinne): implement it. + throw new NotImplementedException(); + } + tf.Runner.TFE_TapeSetRecordOperation(_forward.Signature.Name, to_record, inference_args, backward_function); } /// @@ -65,66 +86,95 @@ namespace Tensorflow.Functions /// (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) { + var capture_mapping = zip(forward_graph.Outputs.Select(t => ops.tensor_id(t)), outputs) + .ToDictionary(x => x.Item1, x => x.Item2); + var captured_inputs = backward.CapturedInputs; + var remapped_captures = captured_inputs.Select(c => + { + if (capture_mapping.TryGetValue(ops.tensor_id(c), out var value)) + { + return value; + } + else + { + return c; + } + }).ToArray(); + if(remapped_captures.Where(t => t is not EagerTensor).Any(t => t.graph == forward_graph)) + { + var incorrect_mapping = remapped_captures.Where(t => t is not EagerTensor && t.graph != forward_graph); + throw new RuntimeError($"Failed to map all backward graph captures to " + + $"the forward graph. 
Incorrectly mapped: {string.Join(", ", incorrect_mapping)}"); + } + + Dictionary variant_zeros_like = new Dictionary(); var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length; var recorded_outputs = new Tensors(); - var trainable_recorded_outputs = 0; - foreach (var (output_index, output) in enumerate(outputs)) + int trainable_recorded_outputs = 0; + var skip_positions = new HashSet(); + var relevant_outputs = outputs; + foreach (var (output_index, output) in enumerate(relevant_outputs)) { if (trainable_recorded_outputs < backward_function_inputs) recorded_outputs.Add(output); - if (gradients_util.IsTrainable(output)) - trainable_recorded_outputs += 1; + if (backprop_util.IsTrainable(output)) + trainable_recorded_outputs++; + else + skip_positions.Add(output_index); + if (output.dtype == dtypes.variant) + variant_zeros_like[output_index] = default_gradient.zeros_like(output); } - if(_backward_function_wrapper == null) + _backward_function_wrapper = (args, unneeded_gradients) => { - var capture_mapping = new Dictionary(); - foreach (var (i, output) in enumerate(outputs)) - capture_mapping[forward_graph.Outputs[i].Id] = output; - - var remapped_captures = new Tensors(); - foreach (var capture in backward.CapturedInputs) - { - if (capture_mapping.ContainsKey(capture.Id)) - remapped_captures.Add(capture_mapping[capture.Id]); - } - - var skip_positions = new List(); - foreach (var (output_index, output) in enumerate(outputs)) + if(backward.Outputs is null || backward.Outputs.Length == 0) { - if (!gradients_util.IsTrainable(output)) - skip_positions.Add(output_index); + return backward.FlatStructuredOutputs; } - _backward_function_wrapper = (args, unneeded_gradients) => + var processed_args = new Tensors(); + int input_index = 0; + foreach (var (output_index, arg) in enumerate(args)) { - var processed_args = new Tensors(); - var input_index = 0; - foreach (var (output_index, arg) in enumerate(args)) + if (skip_positions.Contains(output_index)) + continue; + if (arg is null) + { + var input_placeholder = backward.Inputs[input_index]; + Tensor variant_arg; + if (input_placeholder.dtype == dtypes.variant) + { + variant_arg = variant_zeros_like[output_index]; + } + else + { + var (shape, type) = default_gradient.shape_and_dtype(input_placeholder); + + variant_arg = array_ops.zeros(shape, type); + } + processed_args.Add(variant_arg); + } + else { - if (skip_positions.Contains(output_index)) - continue; - if (arg == null) - throw new NotImplementedException(""); processed_args.Add(arg); - input_index += 1; - if (input_index >= backward_function_inputs) - break; } + input_index++; + if (input_index >= backward_function_inputs) + break; + } - tf.Logger.Debug($"Invoke backward function: {backward.Name}"); - var gradients = backward.CallFlat(processed_args, remapped_captures); + tf.Logger.Debug($"Invoke backward function: {backward.Name}"); + var gradients = backward.CallFlat(processed_args, remapped_captures); - foreach (var unneeded_gradient_index in unneeded_gradients) - { - var index = Convert.ToInt32(unneeded_gradient_index); - if (gradients.Length <= index) - gradients.Insert(index, null); - } + foreach (var unneeded_gradient_index in unneeded_gradients) + { + var index = Convert.ToInt32(unneeded_gradient_index); + if (gradients.Length <= index) + gradients.Insert(index, null); + } - return gradients; - }; - } + return gradients; + }; return (_backward_function_wrapper, recorded_outputs); } @@ -143,7 +193,7 @@ namespace Tensorflow.Functions } } - var backwards_graph = 
new FuncGraph(_func_graph.Name); + var backwards_graph = new FuncGraph(monomorphic_function_utils._backward_name(_func_graph.Name)); backwards_graph.as_default(); var gradients_wrt_outputs = new List(); foreach (var output in trainable_outputs) @@ -153,6 +203,7 @@ namespace Tensorflow.Functions gradients_wrt_outputs.Add(gradient_placeholder); handle_data_util.copy_handle_data(output, gradient_placeholder); } + // TODO(Rinne): with ops.device(None) var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), _func_graph.Inputs, grad_ys: gradients_wrt_outputs.ToArray(), @@ -175,7 +226,8 @@ namespace Tensorflow.Functions backwards_graph.Inputs = gradients_wrt_outputs.Concat(backwards_graph.internal_captures).ToArray(); backwards_graph.Outputs.AddRange(gradients_wrt_inputs.Where(x => x is not null)); - var (forward_function, backward_function) = monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph); + var (wrapped_forward_function, wrapped_backward_function) = + monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph); //var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; //var backward_function_attr = new Dictionary(); //backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; @@ -189,10 +241,11 @@ namespace Tensorflow.Functions // _func_graph.Inputs, _func_graph.Outputs, // monomorphic_function_utils._parse_func_attrs(forward_function_attr)); - return (forward_function, _func_graph, backward_function, null, 0); + return (wrapped_forward_function, _func_graph, wrapped_backward_function, null, 0); } - public virtual EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args) + public virtual (EagerDefinedFunction, FuncGraph, ConcreteFunction, List, int) + ForwardAndBackwardFunctions(Tensors inference_args) { throw new NotImplementedException(""); } diff --git a/src/TensorFlowNET.Core/Functions/TracingCompiler.cs b/src/TensorFlowNET.Core/Functions/TracingCompiler.cs index 8a844671..fb109595 100644 --- a/src/TensorFlowNET.Core/Functions/TracingCompiler.cs +++ b/src/TensorFlowNET.Core/Functions/TracingCompiler.cs @@ -73,12 +73,12 @@ namespace Tensorflow.Functions private static string male_cache_key(Tensor[] inputs) { - string res = ""; - foreach (var input in inputs) - { - res += $"{input.name}_{input.Id}"; - } - return res; + //string res = ""; + //foreach (var input in inputs) + //{ + // res += $"{input.name}_{input.Id}"; + //} + return inputs.Length.ToString(); } } } diff --git a/src/TensorFlowNET.Core/Functions/monomorphic_function.cs b/src/TensorFlowNET.Core/Functions/monomorphic_function.cs index acf00597..7cb5c705 100644 --- a/src/TensorFlowNET.Core/Functions/monomorphic_function.cs +++ b/src/TensorFlowNET.Core/Functions/monomorphic_function.cs @@ -153,7 +153,7 @@ namespace Tensorflow.Functions foreach(var tape in tf.GetTapeSet()) { tape.RecordOperation(_inference_function.Signature.Name, to_record, - inference_args.Select(t => new TapeTensor(t)).ToArray(), backward_function); + inference_args, backward_function); } } diff --git a/src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs b/src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs index eee98a61..743ed0d8 100644 --- a/src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs +++ b/src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs @@ -9,7 +9,7 @@ namespace Tensorflow.Gradients /// Map from tensor to how many references still exist for this 
tensor in /// the tape. /// - public UnorderedMap tensor_usage_counts { get; set; } + public UnorderedMap tensor_usage_counts { get; set; } /// /// Maps from op ID to how many output tensors of this op still need to have /// their gradients computed. @@ -19,7 +19,7 @@ namespace Tensorflow.Gradients public BackpropInitialState() { op_tape = new OpTape(); - tensor_usage_counts = new UnorderedMap(); + tensor_usage_counts = new UnorderedMap(); op_missing_tensor = new UnorderedMap(); } } diff --git a/src/TensorFlowNET.Core/Gradients/GradientTape.cs b/src/TensorFlowNET.Core/Gradients/GradientTape.cs index 31517e58..b5fd373e 100644 --- a/src/TensorFlowNET.Core/Gradients/GradientTape.cs +++ b/src/TensorFlowNET.Core/Gradients/GradientTape.cs @@ -67,40 +67,59 @@ namespace Tensorflow.Gradients /// /// /// - public Tensor gradient(Tensor target, Tensor source) + public Tensor gradient(Tensor target, Tensor source, List output_gradients = null, + string unconnected_gradients = null) { + if(_tape is null) + { + throw new RuntimeError("A non-persistent GradientTape can only be used to " + + "compute one set of gradients (or jacobians)."); + } + ITape tape = stop_recording(); var results = tf.Runner.TFE_TapeGradient(tape, new[] { target }, new[] { source }, - null); + output_gradients, + new[] { source }, + unconnected_gradients); return results[0]; } - public Tensor gradient(Tensor target, ResourceVariable source) + public Tensor gradient(Tensor target, ResourceVariable source, List output_gradients = null, + string unconnected_gradients = null) { - var results = gradient(target, new List { source }); + var results = gradient(target, new List { source }, output_gradients, unconnected_gradients); return results[0]; } - public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources) + public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources, List output_gradients = null, + string unconnected_gradients = null) { - var results = gradient(target, new List { sources.Item1, sources.Item2 }); + var results = gradient(target, new List { sources.Item1, sources.Item2 }, output_gradients, unconnected_gradients); return (results[0], results[1]); } - public Tensor[] gradient(Tensor target, IEnumerable sources) + public Tensor[] gradient(Tensor target, IEnumerable sources, List output_gradients = null, + string unconnected_gradients = null) { + if (_tape is null) + { + throw new RuntimeError("A non-persistent GradientTape can only be used to " + + "compute one set of gradients (or jacobians)."); + } var tape = stop_recording(); var results = tf.Runner.TFE_TapeGradient(tape, new[] { target }, sources.Select(x => x.Handle).ToArray(), - null); + output_gradients, + sources.Select(x => x.Handle).ToArray(), + unconnected_gradients); if (!tape.Persistent) { diff --git a/src/TensorFlowNET.Core/Gradients/ITape.cs b/src/TensorFlowNET.Core/Gradients/ITape.cs index dbd085ea..07594dab 100644 --- a/src/TensorFlowNET.Core/Gradients/ITape.cs +++ b/src/TensorFlowNET.Core/Gradients/ITape.cs @@ -6,24 +6,31 @@ namespace Tensorflow.Gradients public interface ITape { void SetTapeId(int id); - bool ShouldRecord(Tensor[] tensors); + bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes); void StartRecord(); void StopRecord(); bool Persistent { get; } void RecordOperation(string op_type, - Tensor[] input_tensors, TapeTensor[] output_tensors, + long[] input_tensor_id, + TF_DataType[] input_dtypes, BackwardFunction backward_function); - void 
VariableAccessed(ResourceVariable variable); + void RecordOperation(string op_type, + Tensor[] outputs, + Tensor[] inputs, + BackwardFunction backward_function); + + void VariableAccessed(IVariableV1 variable); void Watch(Tensor x); - ResourceVariable[] WatchedVariables(); + IVariableV1[] WatchedVariables(); - Tensor[] ComputeGradient(Tensor[] target_tensor_ids, - Tensor[] source_tensor_ids, - UnorderedMap sources_that_are_targets, - Tensor[] output_gradients); + Tensor[] ComputeGradient(long[] target_tensor_ids, + long[] source_tensor_ids, + UnorderedMap sources_that_are_targets, + List output_gradients, + bool build_default_zeros_grads); } } diff --git a/src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs b/src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs index 537369dd..7665fa01 100644 --- a/src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs +++ b/src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs @@ -9,9 +9,9 @@ namespace Tensorflow.Gradients { public string op_type { get; set; } public TapeTensor[] output_tensor_info { get; set; } - public Tensor[] input_tensor_id { get; set; } + public long[] input_tensor_id { get; set; } public BackwardFunction backward_function { get; set; } public override string ToString() - => $"{op_type}, inputs: {string.Join(",", input_tensor_id.Select(x => x.Id))}"; + => $"{op_type}, inputs: {string.Join(",", input_tensor_id)}"; } } diff --git a/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs b/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs index 73c9e501..8a4a41f6 100644 --- a/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs +++ b/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs @@ -2,235 +2,246 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Util; +using static Tensorflow.Binding; namespace Tensorflow.Gradients { public partial class Tape { - // int kMinAggregateCount = 4; - // int kMinAggregateBytes = 128 * 1024 * 1024; + static readonly int kMinAggregateCount = 4; + static readonly int kMinAggregateBytes = 128 * 1024 * 1024; + private static UnorderedMap> _functionsAcceptingNoneForIndicesMap; - public Tensor[] ComputeGradient(Tensor[] target_tensor_ids, - Tensor[] source_tensor_ids, - UnorderedMap sources_that_are_targets, - Tensor[] output_gradients) + static Tape() { - var sources_set = new UnorderedSet(source_tensor_ids); - // var gradients_size = new UnorderedMap(); - var functionsAcceptingNoneForIndicesMap = FunctionsAcceptingNoneForIndicesMap(); - var state = PrepareBackprop( - target_tensor_ids, tensor_tape_, op_tape_, sources_set, _persistent); - var op_stack = InitialStack(state.op_tape, state.op_missing_tensor); - var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets, - output_gradients, - tensor_tape_, - state.op_tape); + _functionsAcceptingNoneForIndicesMap = new(); + _functionsAcceptingNoneForIndicesMap.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet(new[] { 1 })); + _functionsAcceptingNoneForIndicesMap.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet(new[] { 1 })); + _functionsAcceptingNoneForIndicesMap.Add("FusedBatchNorm", new UnorderedSet(new[] { 1, 2, 3, 4 })); + } - while (!op_stack.empty()) + public Tensor[] ComputeGradient(long[] target_tensor_ids, + long[] source_tensor_ids, + UnorderedMap sources_that_are_targets, + List output_gradients, + bool build_default_zeros_grads) + { + UnorderedSet sources_set = new(source_tensor_ids); + BackpropInitialState state = PrepareBackprop(target_tensor_ids, tensor_tape_, op_tape_, sources_set, Persistent); + 
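+ // Rough flow, inferred from the helpers in this file: PrepareBackprop walks back from the targets and counts, per op, how many of its output gradients are still pending; + // InitialStack seeds the ready queue with the ops that have none pending, and InitialGradients maps each target id to its incoming gradient (or a ones-like tensor when none was supplied). + // The loop below then drains op_stack, invokes each op's backward function, and enqueues a producer op once all of its output gradients are available.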
var op_stack = InitialStack(state.op_tape, state.op_missing_tensor); + var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets, output_gradients, tensor_tape_, state.op_tape); + UnorderedMap gradients_size = new(); + while(op_stack.Count > 0) { - var op = op_stack.Dequeue(); - if (!state.op_tape.find(op, out var trace)) + long op = op_stack.Dequeue(); + if(!state.op_tape.TryGetValue(op, out var op_it)) + { continue; - - // Console.WriteLine($"ComputeGradient: {state.op_tape[op].op_type}"); + } + var trace = op_it; state.op_tape.erase(op); - - var out_gradients = new List(trace.output_tensor_info.Length); - var unneeded_gradients = new List(); - for (int i = 0; i < trace.input_tensor_id.Length; i++) + List out_gradients = new(); + List unneeded_gradients = new(); + for(int i = 0, end = trace.input_tensor_id.Length; i < end; i++) { - var in_tensor_id = trace.input_tensor_id[i]; - if (!tensor_tape_.find(in_tensor_id) && - !sources_set.find(in_tensor_id)) + long in_tensor_id = trace.input_tensor_id[i]; + if(!tensor_tape_.find(in_tensor_id) && !sources_set.find(in_tensor_id)) + { unneeded_gradients.Add(i); + } } bool any_gradient_nonzero = false; - var zero_indices = new List(); - for (int i = 0; i < trace.output_tensor_info.Length; ++i) + List zero_indices = new(); + for(int i = 0, end = trace.output_tensor_info.Length; i < end; i++) { - var id = trace.output_tensor_info[i].GetTensor(); - if (!gradients.find(id, out var grad_it)) + long id = trace.output_tensor_info[i].GetID(); + if(!gradients.TryGetValue(id, out var grad_it)) { - if (functionsAcceptingNoneForIndicesMap.find(trace.op_type, out var func_name_it) && - func_name_it.find(i)) + out_gradients.Add(null); + if (build_default_zeros_grads) { - out_gradients.Add(null); - } - else - { - out_gradients.Add(null); - zero_indices.Add(i); + if(!_functionsAcceptingNoneForIndicesMap.TryGetValue(trace.op_type, out var func_name_it) || + !func_name_it.find(i)) + { + zero_indices.Add(i); + } } } else { any_gradient_nonzero = true; - var new_gradients = grad_it.Count == 1 ? - grad_it[0] : - gen_math_ops.add_n(grad_it.ToArray()); // vspace.AggregateGradients - + Tensor new_gradients; + if (grad_it.Count == 1) + { + new_gradients = grad_it[0]; + } + else + { + new_gradients = AggregateGradients(grad_it); + } if (!sources_set.find(id)) + { gradients.Remove(id); + } else { - // grad_it.Clear(); - // grad_it.Add(new_gradients); - // vspace.MarkAsResult(new_gradients); + grad_it.Clear(); + grad_it.Add(new_gradients); + // MarkAsResult } out_gradients.Add(new_gradients); } } - Tensor[] in_gradients; + Tensor[] in_gradients = new Tensor[0]; if (any_gradient_nonzero) { - // foreach (var i in zero_indices) - // out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); - - in_gradients = trace.backward_function(out_gradients.ToArray(), unneeded_gradients.ToArray()); - - if (in_gradients.Length != trace.input_tensor_id.Length && in_gradients.Length + unneeded_gradients.Count != trace.input_tensor_id.Length) - throw new RuntimeError($"Recorded operation '{trace.op_type}' returned too few gradients. 
Expected {trace.input_tensor_id.Length} but received {in_gradients.Count()}"); - if (!_persistent) + foreach(var i in zero_indices) { - // trace.backward_function_deleter(trace.backward_function); - trace.backward_function = null; + out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); } + in_gradients = CallBackwardFunction(trace.backward_function, unneeded_gradients, out_gradients); } else { - in_gradients = new Tensor[trace.input_tensor_id.Length]; + out_gradients.Clear(); } - - bool skip_unneeded_id = trace.input_tensor_id.Length > in_gradients.Length; - for (int i = 0, k = 0; i < in_gradients.Length && k < trace.input_tensor_id.Count(); ++i, ++k) + + for(int i = 0, end = in_gradients.Length; i < end; i++) { - if (skip_unneeded_id && unneeded_gradients.Contains(k)) ++k; - var id = trace.input_tensor_id[k]; - if (in_gradients[i] != null) + long id = trace.input_tensor_id[i]; + if (in_gradients[i] is not null) { - var unaggregated_grads = gradients[id]; + var unaggregated_grads = gradients.SetDefault(id, new List()); unaggregated_grads.Add(in_gradients[i]); - /*if (unaggregated_grads.Count > kMinAggregateCount) + if(unaggregated_grads.Count > kMinAggregateCount) { - if (!gradients_size.find(id, out var size)) + if(!gradients_size.TryGetValue(id, out var size)) { - size = (long)unaggregated_grads[0].size; + size = NumElements(unaggregated_grads[0]); gradients_size.emplace(id, size); } - - if (unaggregated_grads.Count * size * 4 > kMinAggregateBytes) + if(unaggregated_grads.Count * size * 4 > kMinAggregateBytes) { - throw new NotImplementedException(""); + Tensor grad = AggregateGradients(unaggregated_grads); + unaggregated_grads.Clear(); + unaggregated_grads.Add(grad); } - }*/ + } } - - if (!state.tensor_usage_counts.find(id)) + if(!state.tensor_usage_counts.find(id)) + { continue; - + } state.tensor_usage_counts[id]--; - if (state.tensor_usage_counts[id] > 0) + if(state.tensor_usage_counts[id] > 0) + { continue; - - if (!tensor_tape_.find(id, out var tape_it)) + } + if (!tensor_tape_.TryGetValue(id, out var tape_it)) { - if (gradients.find(id, out var grad_it)) + if (gradients.find(id)) { - // foreach (var g in grad_it) - // DeleteGradient(g); gradients.erase(id); } continue; } - - var op_id = tape_it; - if (op_id == -1) + long op_id = tape_it; + if(op_id == -1) + { continue; - - if (state.op_missing_tensor.find(op_id, out var missing_it)) + } + if(state.op_missing_tensor.find(op_id)) { state.op_missing_tensor[op_id]--; - if (state.op_missing_tensor[op_id] == 0) + if(state.op_missing_tensor[op_id] == 0) + { op_stack.Enqueue(op_id); + } } } } - if (state.op_tape.Count > 0) + if(state.op_tape.Count > 0) + { throw new RuntimeError("Invalid tape state."); - - var result = new Tensor[source_tensor_ids.Length]; - var j = 0; - foreach (var id in source_tensor_ids) + } + Tensor[] result = new Tensor[source_tensor_ids.Length]; + for(int i = 0; i < source_tensor_ids.Length; i++) { - if (gradients.find(id, out var grad_it)) + long tensor_id = source_tensor_ids[i]; + if(!gradients.TryGetValue(tensor_id, out var grad_it)) { - if (grad_it.Count > 1) - result[j] = gen_math_ops.add_n(grad_it.ToArray()); - else - result[j] = grad_it[0]; + result[i] = null; + } + else + { + if(grad_it.Count > 1) + { + Tensor grad = AggregateGradients(grad_it); + grad_it.Clear(); + grad_it.Add(grad); + } + result[i] = grad_it[0]; } - j++; } - return result; } UnorderedMap> FunctionsAcceptingNoneForIndicesMap() { - var m = new UnorderedMap>(); - m.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet(new[] { 1 })); - 
m.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet(new[] { 1 })); - m.Add("FusedBatchNorm", new UnorderedSet(new[] { 1, 2, 3, 4 })); - return m; + return _functionsAcceptingNoneForIndicesMap; } - UnorderedMapEnumerable> InitialGradients(Tensor[] target_tensor_ids, - UnorderedMap sources_that_are_targets, - Tensor[] output_gradients, + UnorderedMap> InitialGradients(long[] target_tensor_ids, + UnorderedMap sources_that_are_targets, + List output_gradients, TensorTape tensor_tape, OpTape op_tape) { - var result = new UnorderedMapEnumerable>(); - for (int i = 0; i < target_tensor_ids.Length; ++i) + var result = new UnorderedMap>(); + for(int i = 0, end = target_tensor_ids.Length; i < end; i++) { - var id = target_tensor_ids[i]; - if (output_gradients.Length == 0 || output_gradients[i] == null) + long id = target_tensor_ids[i]; + if( output_gradients is null ||output_gradients.Count == 0 || output_gradients[i] is null) { - if (tensor_tape.find(id, out var tensor_id) && tensor_id != null) + if(tensor_tape.TryGetValue(id, out var tensor_it) && tensor_it != -1) { - if (!op_tape.find(tensor_tape[id], out var op_it)) + if(!op_tape.TryGetValue(tensor_it, out var op_it)) + { throw new RuntimeError("Internal state of the gradient tape is invalid: " + - "failed to find operation producing a tensor"); + "failed to find operation producing a tensor."); + } bool found = false; - for (int j = 0; j < op_it.output_tensor_info.Length; ++j) + for(int j = 0; j < op_it.output_tensor_info.Length; j++) { - if (op_it.output_tensor_info[j].GetTensor() == id) + if (op_it.output_tensor_info[j].GetID() == id) { found = true; - var ones = op_it.output_tensor_info[j].OnesLike(); - result[id].Add(ones); + Tensor ones_like = BuildOnesLike(op_it.output_tensor_info[j]); + result.SetDefault(id, new List()).Add(ones_like); break; } } - if (!found) { - throw new ValueError("Internal state of the gradient tape is invalid: " + - "none of operations outputs match expected tensor"); + throw new RuntimeError("Internal state of the gradient tape is invalid: " + + "none of operations outputs match expected tensor."); } } else { - if (sources_that_are_targets.find(id, out var source_tensor)) - result[id].Add(source_tensor.OnesLike()); + if(sources_that_are_targets.TryGetValue(id, out var source_tensor)) + { + Tensor ones_like = BuildOnesLike(source_tensor); + result.SetDefault(id, new List()).Add(ones_like); + } } } else { - result[id].Add(output_gradients[i]); + result.SetDefault(id, new List()).Add(output_gradients[i]); } } @@ -248,5 +259,26 @@ namespace Tensorflow.Gradients } return result; } + + Tensor BuildOnesLike(TapeTensor t) + { + return t.OnesLike(); + } + + Tensor AggregateGradients(List gradient_tensors) + { + if(gradient_tensors.Count == 0) + { + return gradient_tensors[0]; + } + return tf.add_n(gradient_tensors.ToArray()); + } + + void DeleteGradient(Tensor gradient) + { + // Do not do anything here. Because GC will collect it when it has no reference. 
+ } + + long NumElements(Tensor tensor) => 1; } } diff --git a/src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs b/src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs index 2ab4ddbb..f8f356e7 100644 --- a/src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs +++ b/src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs @@ -5,63 +5,62 @@ namespace Tensorflow.Gradients { public partial class Tape { - public BackpropInitialState PrepareBackprop(Tensor[] target, + public BackpropInitialState PrepareBackprop(long[] target, TensorTape tensor_tape, OpTape op_tape, - UnorderedSet sources_set, + UnorderedSet sources_set, bool persistent_tape) { + Stack tensor_stack = new Stack(); + foreach(var t in target) + { + tensor_stack.Push(t); + } BackpropInitialState result = new BackpropInitialState(); - var tensor_stack = new Queue(target); - while (tensor_stack.Count > 0) + while(tensor_stack.Count > 0) { - var tensor_id = tensor_stack.Dequeue(); - - if (!tensor_tape.find(tensor_id, out var op_id)) + long tensor_id = tensor_stack.Pop(); + if(!tensor_tape.TryGetValue(tensor_id, out var op_id)) + { continue; - - if (op_id == -1 || - !op_tape.find(op_id, out var op_it) || - result.op_tape.find(op_id, out var result_op_it)) + } + if(op_id == -1 || !op_tape.TryGetValue(op_id, out var op_it) + || result.op_tape.find(op_id)) + { continue; - + } result.op_tape.emplace(op_id, op_it); - - foreach (var it in op_it.input_tensor_id) + foreach(var it in op_it.input_tensor_id) { - if (result.tensor_usage_counts.find(it)) + if(result.tensor_usage_counts.find(it)) + { result.tensor_usage_counts[it]++; + } else { result.tensor_usage_counts[it] = 1; if (tensor_tape.find(it)) - tensor_stack.Enqueue(it); + { + tensor_stack.Push(it); + } } } - if (!persistent_tape) - op_tape.Remove(op_id); + { + op_tape.erase(op_id); + } } - - foreach (var pair in result.tensor_usage_counts) + foreach(var pair in result.tensor_usage_counts) { - if (tensor_tape.find(pair.Key, out var it) && it != -1) - result.op_missing_tensor[it] += 1; + if(tensor_tape.TryGetValue(pair.Key, out var it) && it != -1) + { + result.op_missing_tensor[it]++; + } } - if (!persistent_tape) { - // Call destructors for all unneeded gradient functions and - // clear the op_tape. We can clear the tape because ownership of - // backward functions that will be used for gradient computation - // has been transferred to `result`. 
- /*for (const auto&op_pair : *op_tape) { - op_pair.second.backward_function_deleter( - op_pair.second.backward_function); - }*/ op_tape.Clear(); } - return result; } } diff --git a/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs b/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs index a692f1f4..708b9121 100644 --- a/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs +++ b/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs @@ -8,34 +8,45 @@ namespace Tensorflow.Gradients public partial class Tape { long next_op_id_ = 0; - UnorderedMap tensor_usage_; + UnorderedMap tensor_usage_; public void RecordOperation(string op_type, - Tensor[] input_tensors, TapeTensor[] output_tensors, + long[] input_tensor_id, + TF_DataType[] input_dtypes, BackwardFunction backward_function) { - if (!ShouldRecord(input_tensors)) + if (!ShouldRecord(input_tensor_id, input_dtypes)) return; - var op_id = next_op_id_++; - foreach (var i in input_tensors) + foreach (var i in input_tensor_id) + { tensor_usage_[i]++; - + } + long op_id = next_op_id_++; + foreach (var o in output_tensors) { tf.Logger.Debug($"RecordOperation: tensor_tape_[{o.GetID()}] = {op_id}"); - tensor_tape_[o.GetTensor()] = op_id; - tensor_usage_[o.GetTensor()] = 1; + tensor_tape_[o.GetID()] = op_id; + tensor_usage_[o.GetID()] = 1; } op_tape_[op_id] = new OpTapeEntry { op_type = op_type, - output_tensor_info = output_tensors, - input_tensor_id = input_tensors, + output_tensor_info = output_tensors.ToArray(), + input_tensor_id = input_tensor_id.ToArray(), backward_function = backward_function }; } + + public void RecordOperation(string op_type, + Tensor[] outputs, + Tensor[] inputs, + BackwardFunction backward_function) + { + tf.Runner.TFE_TapeSetRecordOperation(op_type, outputs, inputs, backward_function); + } } } diff --git a/src/TensorFlowNET.Core/Gradients/Tape.cs b/src/TensorFlowNET.Core/Gradients/Tape.cs index 15caf81b..648666bb 100644 --- a/src/TensorFlowNET.Core/Gradients/Tape.cs +++ b/src/TensorFlowNET.Core/Gradients/Tape.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using Tensorflow.Util; using static Tensorflow.Binding; @@ -29,7 +30,7 @@ namespace Tensorflow.Gradients _created_eagerly = tf.Context.executing_eagerly(); tensor_tape_ = new TensorTape(); op_tape_ = new OpTape(); - tensor_usage_ = new UnorderedMap(); + tensor_usage_ = new UnorderedMap(); if(_created_eagerly) tf.Context.start_step(); // nesting_id = ++tape_nesting_id_counter; @@ -42,29 +43,28 @@ namespace Tensorflow.Gradients public void Watch(Tensor x) { tf.Logger.Debug($"Watch tensor id={x.Id}, name={x.name}"); - tensor_tape_.emplace(x, -1); + tensor_tape_.emplace(x.Id, -1); } - public bool ShouldRecord(Tensor[] tensors) + public bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes) { - var dtypes = tensors.Select(x => x.dtype).ToArray(); - for (int i = 0; i < tensors.Length; ++i) + Debug.Assert(tensor_ids.Length == tensor_dtypes.Length); + for (int i = 0; i < tensor_ids.Length; ++i) { - if (tensor_tape_.find(tensors[i])) + if (tensor_tape_.find(tensor_ids[i]) && IsDtypeTrainable(tensor_dtypes[i])) { - if (IsDtypeTrainable(dtypes[i])) - return true; + return true; } } return false; } - public void VariableAccessed(ResourceVariable variable) + public void VariableAccessed(IVariableV1 variable) { Watch(variable.Handle); } - public ResourceVariable[] WatchedVariables() + public IVariableV1[] WatchedVariables() { return null; } diff --git 
a/src/TensorFlowNET.Core/Gradients/TapeTensor.cs b/src/TensorFlowNET.Core/Gradients/TapeTensor.cs index 210794d8..3ad19768 100644 --- a/src/TensorFlowNET.Core/Gradients/TapeTensor.cs +++ b/src/TensorFlowNET.Core/Gradients/TapeTensor.cs @@ -1,27 +1,63 @@ -using static Tensorflow.Binding; +using OneOf; +using static Tensorflow.Binding; namespace Tensorflow.Gradients { public class TapeTensor { - Tensor tensor; - long id => tensor.Id; - TF_DataType dtype => tensor.dtype; - Shape shape => tensor.shape; + internal Tensor tensor; + internal long id; + internal TF_DataType dtype; + internal OneOf shape; + + public TapeTensor(long id, TF_DataType dtype, Shape shape) + { + this.id = id; + this.dtype = dtype; + this.shape = shape; + } + + public TapeTensor(long id, TF_DataType dtype, Tensor shape) + { + this.id = id; + this.dtype = dtype; + this.shape = shape; + } public TapeTensor(Tensor tensor) { + this.id = tensor.Id; + this.dtype = tensor.dtype; + this.shape = tensor.shape; this.tensor = tensor; } - public long GetID() => tensor.Id; - public Tensor GetTensor() => tensor; + public long GetID() => id; public Tensor ZerosLike() - => tf.zeros(shape: shape, dtype: dtype); + { + if(dtype == dtypes.resource) + { + return null; + } + if(shape.Index == 1) + { + return tf.zeros_like(shape.AsT1); + } + return tf.zeros(shape.AsT0, dtype); + } public Tensor OnesLike() - => tf.ones(shape: shape, dtype: dtype); + { + if (shape.Index == 1) + { + return tf.ones_like(shape.AsT1); + } + return tf.ones(shape.AsT0, dtype); + } + + //public Tensor OnesLike() + // => tf.ones(shape: shape, dtype: dtype); public override string ToString() => $"{id}, {shape}, {dtype.as_numpy_name()}"; diff --git a/src/TensorFlowNET.Core/Gradients/TensorTape.cs b/src/TensorFlowNET.Core/Gradients/TensorTape.cs index b9424f91..3f069082 100644 --- a/src/TensorFlowNET.Core/Gradients/TensorTape.cs +++ b/src/TensorFlowNET.Core/Gradients/TensorTape.cs @@ -7,7 +7,7 @@ namespace Tensorflow.Gradients /// produced this tensor. A value of -1 means that the tensor was directly /// watched and not the result of any operation in the tape. /// - public class TensorTape : UnorderedMap + public class TensorTape : UnorderedMap { } diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index 10166911..71d3d9ca 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -704,32 +704,7 @@ namespace Tensorflow public static int PossibleTapeGradientTypes(Tensor[] tensors) { - var tape_set = tf.GetTapeSet(); - bool some_tape_watching = false; - if(tape_set is not null && tape_set.Count > 0) - { - foreach(var tape in tape_set) - { - if (tape.ShouldRecord(tensors)) - { - if(tape.Persistent || some_tape_watching) - { - return POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER; - } - some_tape_watching = true; - } - } - } - // skip the forward_accumulators. 
- - if (some_tape_watching) - { - return POSSIBLE_GRADIENT_TYPES_FIRST_ORDER; - } - else - { - return POSSIBLE_GRADIENT_TYPES_NONE; - } + return tf.Runner.TFE_TapeSetPossibleGradientTypes(tensors); } /// diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs index 9ef0b95b..ea415969 100644 --- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs @@ -215,6 +215,16 @@ public class FuncGraph : Graph, IDisposable return tensor; } + public void watch_variable(IVariableV1 v) + { + if (_resource_tensor_inputs.Contains(v.Handle)) + { + return; + } + _watched_variables.Add(new WeakReference(v)); + //this = this.outer_graph; + } + Tensor capture_eager_tensor(Tensor tensor, string name) { Tensor graph_const = null; diff --git a/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs b/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs index 68d6d059..58e7ef8c 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs @@ -4,10 +4,10 @@ public interface IOptimizer { Tensor[] aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars); Tensor[] clip_gradients(Tensor[] grads); - void apply_gradients((Tensor, ResourceVariable) grads_and_vars, + void apply_gradients((Tensor, IVariableV1) grads_and_vars, string name = null, bool experimental_aggregate_gradients = true); - void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars, + void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars, string name = null, bool experimental_aggregate_gradients = true); } diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index 43dc8cd4..e5f55631 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -208,9 +208,9 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status); - //[DllImport(TensorFlowLibName)] - //public static extern IntPtr GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); - //[DllImport(TensorFlowLibName)] - //public static extern void SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data); + [DllImport(TensorFlowLibName)] + public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); + [DllImport(TensorFlowLibName)] + public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); } } diff --git a/src/TensorFlowNET.Core/Operations/functional_ops.cs b/src/TensorFlowNET.Core/Operations/functional_ops.cs index 9c2e85d1..10547921 100644 --- a/src/TensorFlowNET.Core/Operations/functional_ops.cs +++ b/src/TensorFlowNET.Core/Operations/functional_ops.cs @@ -39,7 +39,7 @@ namespace Tensorflow if (config is null) { - config = function_utils.get_disabled_rewriter_config().ToString(); + config = function_utils.get_disabled_rewriter_config().ToStringUtf8(); } if (executor_type is null) @@ -49,6 +49,8 @@ namespace Tensorflow if (executing_eagerly) { + // TODO(Rinne): implement it. 
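+ // A possible shape for this branch (hypothetical sketch only; `func`, `tin`, `tout` and `args` stand for this method's function, type-list and input arguments, by analogy with fill_eager_fallback elsewhere in this patch): + // return execute.executes(func.Name, tout.Length, args, + //     new object[] { "Tin", tin, "Tout", tout, "f", func, "config_proto", config, "executor_type", executor_type }, + //     tf.Context, name);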
diff --git a/src/TensorFlowNET.Core/Operations/functional_ops.cs b/src/TensorFlowNET.Core/Operations/functional_ops.cs
index 9c2e85d1..10547921 100644
--- a/src/TensorFlowNET.Core/Operations/functional_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/functional_ops.cs
@@ -39,7 +39,7 @@ namespace Tensorflow
 
             if (config is null)
             {
-                config = function_utils.get_disabled_rewriter_config().ToString();
+                config = function_utils.get_disabled_rewriter_config().ToStringUtf8();
             }
 
             if (executor_type is null)
@@ -49,6 +49,8 @@ namespace Tensorflow
 
             if (executing_eagerly)
             {
+                // TODO(Rinne): implement it.
+                throw new NotImplementedException();
             }
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index 93a54af0..1dc6504a 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -17,6 +17,7 @@
 using System;
 using System.Linq;
 using Tensorflow.Contexts;
+using Tensorflow.Eager;
 using static Tensorflow.Binding;
 
 namespace Tensorflow
@@ -210,7 +211,51 @@ namespace Tensorflow
         /// <param name="name">A name for the operation (optional).</param>
         /// <returns>A `Tensor`. Has the same type as `value`.</returns>
         public static Tensor fill<T>(Tensor dims, T value, string name = null)
-            => tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value));
+        {
+            var ctx = tf.Context;
+            if (ctx.executing_eagerly())
+            {
+                try
+                {
+                    var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("Fill", name, dims, value));
+                    return _result[0];
+                }
+                catch (Exception)
+                {
+                    // fall through to the eager fallback
+                }
+                try
+                {
+                    return fill_eager_fallback(dims, value as Tensor, name, ctx);
+                }
+                catch (Exception)
+                {
+                    // fall through to graph construction
+                }
+            }
+            Dictionary<string, object> attrs = new Dictionary<string, object>();
+            attrs["dims"] = dims;
+            attrs["value"] = value;
+            var result = tf.OpDefLib._apply_op_helper("Fill", name, attrs);
+            if (execute.must_record_gradient())
+            {
+                throw new NotImplementedException();
+            }
+            return result.output;
+        }
+
+        public static Tensor fill_eager_fallback(Tensor dims, Tensor value, string name, Context ctx)
+        {
+            object[] attrs = new object[] { "T", value.dtype.as_datatype_enum(), "index_type", dims.dtype.as_datatype_enum() };
+            var _result = execute.executes("Fill", 1, new Tensor[] { dims, value }, attrs, ctx, name);
+
+            if (execute.must_record_gradient())
+            {
+                throw new NotImplementedException();
+            }
+            return _result[0];
+        }
+        //=> tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value));
 
         /// <summary>
         /// Return the reduction indices for computing gradients of s0 op s1 with broadcast.
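
The public surface of fill is unchanged; only the dispatch is new (eager fast path, then eager fallback, then graph construction). For example:

    var dims = tf.constant(new[] { 2, 2 });
    var nines = gen_array_ops.fill(dims, 9.0f);  // 2x2 tensor filled with 9
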
diff --git a/src/TensorFlowNET.Core/Operations/handle_data_util.cs b/src/TensorFlowNET.Core/Operations/handle_data_util.cs
index 66daa5c0..a01efc52 100644
--- a/src/TensorFlowNET.Core/Operations/handle_data_util.cs
+++ b/src/TensorFlowNET.Core/Operations/handle_data_util.cs
@@ -49,8 +49,10 @@ namespace Tensorflow.Operations
                 target_t.HandleData = handle_data;
                 return;
             }
-            // TODO(Rinne): enable it. (currently the internal c api cannot be invoked.)
-            //c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray());
+            Status status = new();
+            var proto = handle_data.ToByteArray();
+            c_api.TFC_SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), proto, proto.Length, status);
+            status.Check(true);
         }
 
         public static HandleData get_resource_handle_data(Tensor graph_op) => ops.get_resource_handle_data(graph_op);
diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
index 3e39338b..c06e822d 100644
--- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
@@ -25,6 +25,7 @@ using static Tensorflow.Binding;
 using Tensorflow.Operations;
 using System.Buffers;
 using Tensorflow.Eager;
+using Tensorflow.Graphs;
 
 namespace Tensorflow
 {
@@ -302,5 +303,18 @@ namespace Tensorflow
         //    return handle_data_util.get_resource_handle_data(handle);
         //}
         }
+
+        public static void variable_accessed(IVariableV1 variable)
+        {
+            if (ops.get_default_graph() is FuncGraph func_graph)
+            {
+                func_graph.watch_variable(variable);
+            }
+            if (variable.Trainable)
+            {
+                foreach (var tape in tf.GetTapeSet())
+                    tape.VariableAccessed(variable);
+            }
+        }
     }
 }
diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
index 4898cca0..935e5545 100644
--- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
+++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
@@ -110,7 +110,7 @@
     https://tensorflownet.readthedocs.io
-    
+    
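
The new variable_accessed helper mirrors the Python-side bookkeeping: a read inside a FuncGraph registers the variable with that graph, and a read of a trainable variable is reported to every active tape. A hedged sketch of the intended call site (the variable name is illustrative):

    // Inside any op that reads `variable`:
    resource_variable_ops.variable_accessed(variable);
    // ...then emit the actual ReadVariableOp against variable.Handle,
    // so the tape can associate the read with the resource.
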
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
index fff3cde5..498ffda7 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
@@ -30,7 +30,7 @@ namespace Tensorflow
     {
         public virtual IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle);
 
-        public Tensor()
+        protected Tensor()
         {
         }
 
@@ -108,6 +108,7 @@ namespace Tensorflow
         protected unsafe void InitTensor(Shape shape, TF_DataType dtype)
         {
             _handle = TF_NewTensor(shape, dtype, null);
+            _id = ops.uid();
         }
 
         protected unsafe void InitTensor(Shape shape, byte[] bytes, TF_DataType dtype)
@@ -116,6 +117,7 @@ namespace Tensorflow
                 _handle = StringTensor(new byte[][] { bytes }, Shape.Scalar);
             else
                 _handle = TF_NewTensor(bytes, shape, dtype);
+            _id = ops.uid();
         }
 
         protected unsafe void InitTensor(Array array, Shape? shape = null)
@@ -166,6 +168,8 @@ namespace Tensorflow
                 string[] val => StringTensor(val, shape),
                 _ => throw new NotImplementedException("")
             };
+
+            _id = ops.uid();
         }
 
         unsafe SafeTensorHandle InitTensor<T>(T[] array, Shape shape, TF_DataType dtype) where T : unmanaged
diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs
index 69dd2c10..d6986af3 100644
--- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs
+++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs
@@ -462,6 +462,7 @@ namespace Tensorflow.Training.Saving.SavedModel
     {
         IEnumerable<ConcreteFunction> _concrete_functions;
         FunctionSpec _function_spec;
+        public IEnumerable<ConcreteFunction> ConcreteFunctions => _concrete_functions;
         public RestoredFunction(Func<Tensors, Tensors> function, string name, FunctionSpec function_spec,
             IEnumerable<ConcreteFunction> concrete_functions) : base(function, name, auto_graph: false)
         {
diff --git a/src/TensorFlowNET.Core/Util/UnorderedMap.cs b/src/TensorFlowNET.Core/Util/UnorderedMap.cs
index fa2b91fe..219a3c14 100644
--- a/src/TensorFlowNET.Core/Util/UnorderedMap.cs
+++ b/src/TensorFlowNET.Core/Util/UnorderedMap.cs
@@ -25,6 +25,19 @@ namespace Tensorflow.Util
             }
         }
 
+        public Tv SetDefault(Tk key, Tv default_value)
+        {
+            if (TryGetValue(key, out var res))
+            {
+                return res;
+            }
+            else
+            {
+                base[key] = default_value;
+                return base[key];
+            }
+        }
+
         public void push_back(Tk key, Tv value)
             => this[key] = value;
diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
index faaa0274..74ce4e8a 100644
--- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
+++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
@@ -9,6 +9,7 @@ using System.Diagnostics;
 using Tensorflow.Checkpoint;
 using Tensorflow.Training.Saving.SavedModel;
 using OneOf;
+using Tensorflow.Graphs;
 
 namespace Tensorflow
 {
@@ -193,6 +194,10 @@ namespace Tensorflow
         /// </summary>
         void variable_accessed(BaseResourceVariable variable)
         {
+            if (ops.get_default_graph() is FuncGraph func_graph)
+            {
+                func_graph.watch_variable(variable as IVariableV1);
+            }
             if (variable.Trainable)
             {
                 foreach (var tape in tf.GetTapeSet())
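
SetDefault behaves like Python's dict.setdefault: it returns the stored value when the key exists, otherwise it installs and returns the supplied default. A usage sketch (the map name is illustrative):

    var tensor_usages = new UnorderedMap<long, List<long>>();
    var uses = tensor_usages.SetDefault(42, new List<long>());  // inserts the empty list on first call
    uses.Add(7);                                                // visible through tensor_usages[42]
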
diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs
index 7aadb206..c261f3ce 100644
--- a/src/TensorFlowNET.Core/ops.cs
+++ b/src/TensorFlowNET.Core/ops.cs
@@ -575,12 +575,8 @@ namespace Tensorflow
 
         public static HandleData get_resource_handle_data(Tensor graph_op)
         {
-            throw new NotImplementedException();
-            // This implementation hasn't been checked for some reasons.
-            // If it throws an exception in the future, please check it.
-
-            //var handle_data = c_api.GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
-            //return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data)));
+            var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
+            return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data)));
         }
 
         public static void dismantle_graph(Graph graph)
diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs
index 0a06df2c..79c955b6 100644
--- a/src/TensorFlowNET.Keras/Engine/Layer.cs
+++ b/src/TensorFlowNET.Keras/Engine/Layer.cs
@@ -27,6 +27,7 @@ using Tensorflow.Keras.Utils;
 using Tensorflow.NumPy;
 using Tensorflow.Train;
 using Tensorflow.Training;
+using Tensorflow.Training.Saving.SavedModel;
 using Tensorflow.Util;
 using static Tensorflow.Binding;
 
@@ -50,7 +51,17 @@ namespace Tensorflow.Keras.Engine
         /// the layer's weights.
         /// </summary>
         protected bool built;
-        public bool Built => built;
+        public bool Built
+        {
+            get
+            {
+                return built;
+            }
+            internal set
+            {
+                built = value;
+            }
+        }
         public bool Trainable => args.Trainable;
         public TF_DataType DType => args.DType;
         public bool AutoCast => args.Autocast;
@@ -179,6 +190,11 @@ namespace Tensorflow.Keras.Engine
         }
         protected List<ILayer> _self_tracked_trackables;
 
+        /// <summary>
+        /// If this value is set, the behavior of the layer call will be changed to directly calling this function.
+        /// </summary>
+        public Func<Tensors, Tensors>? ReplacedCall { get; set; } = null;
+
         public Layer(LayerArgs args)
         {
             Initialize(args);
@@ -259,6 +275,10 @@ namespace Tensorflow.Keras.Engine
         /// </summary>
         protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
         {
+            if (ReplacedCall is not null)
+            {
+                return ReplacedCall(inputs);
+            }
             return inputs;
         }
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Train.cs b/src/TensorFlowNET.Keras/Engine/Model.Train.cs
index 5cf34250..905ea453 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Train.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Train.cs
@@ -35,10 +35,6 @@ namespace Tensorflow.Keras.Engine
         {
             (x, y) = data_handler.DataAdapter.Expand1d(x, y);
             using var tape = tf.GradientTape();
-            //foreach (var variable in TrainableVariables)
-            //{
-            //    tape.watch(variable.Handle);
-            //}
             var y_pred = Apply(x, training: true);
             var loss = compiled_loss.Call(y, y_pred);
 
@@ -70,7 +66,7 @@ namespace Tensorflow.Keras.Engine
             gradients = optimizer.aggregate_gradients(zip(gradients, trainable_variables));
             gradients = optimizer.clip_gradients(gradients);
-            optimizer.apply_gradients(zip(gradients, trainable_variables.Select(x => x as ResourceVariable)),
+            optimizer.apply_gradients(zip(gradients, trainable_variables),
                 experimental_aggregate_gradients: false);
         }
 }
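
Widening apply_gradients from ResourceVariable to IVariableV1 lets the train step above pass TrainableVariables straight through without the Select cast. A hedged sketch of the resulting custom-training-loop shape:

    var gradients = tape.gradient(loss, model.TrainableVariables);
    optimizer.apply_gradients(zip(gradients, model.TrainableVariables),
        experimental_aggregate_gradients: false);
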
diff --git a/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs b/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs
index dcd7535f..e49d757a 100644
--- a/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs
+++ b/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs
@@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Optimizers
             _set_hyper("decay", args.InitialDecay);
         }
 
-        public void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
+        public void apply_gradients((Tensor, IVariableV1) grads_and_vars,
             string name = null,
             bool experimental_aggregate_gradients = true)
             => apply_gradients(new[] { grads_and_vars },
@@ -55,7 +55,7 @@ namespace Tensorflow.Keras.Optimizers
         ///
         ///
         ///
-        public void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
+        public void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
             string name = null,
             bool experimental_aggregate_gradients = true)
         {
@@ -78,7 +78,7 @@ namespace Tensorflow.Keras.Optimizers
             });
         }
 
-        void apply_grad_to_update_var(ResourceVariable var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
+        void apply_grad_to_update_var(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
         {
             _resource_apply_dense(var, grad, apply_state);
             // if var.constraint is not None:
@@ -93,7 +93,7 @@ namespace Tensorflow.Keras.Optimizers
             throw new NotImplementedException("_resource_apply_dense");
         }
 
-        void _distributed_apply(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
+        void _distributed_apply(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
             string name,
             Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
         {
diff --git a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
index aed6769a..9cdd3b50 100644
--- a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
+++ b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
@@ -255,6 +255,25 @@ namespace Tensorflow.Keras.Saving
         /// </summary>
         private void _finalize_saved_model_layers(List<Layer> layers)
         {
+            foreach (var layer in layers)
+            {
+                layer.Built = true;
+                var keras_attr = _get_keras_attr(layer);
+                if (keras_attr is not Trackable trackable)
+                {
+                    continue;
+                }
+                if (trackable.CustomizedFields.TryGetValue("call_and_return_conditional_losses", out var layer_call))
+                {
+                    Debug.Assert(layer_call is RestoredFunction);
+                    var concrete_functions = ((RestoredFunction)layer_call).ConcreteFunctions;
+                    if (concrete_functions is not null && concrete_functions.Count() > 0)
+                    {
+                        layer.ReplacedCall = use_wrapped_call(layer, ((RestoredFunction)layer_call).Apply);
+                    }
+                }
+            }
+
             foreach (var layer in layers)
             {
                 // TODO(Rinne): deal with `RevivedNetwork`.
@@ -265,6 +284,12 @@ namespace Tensorflow.Keras.Saving
             }
         }
 
+        private Func<Tensors, Tensors> use_wrapped_call(Layer layer, Func<Tensors, Tensors> call)
+        {
+            // TODO(Rinne): revise it.
+            return call;
+        }
+
         private void _restore_layer_unconditional_losses(Layer layer)
         {
             // TODO(Rinne): implement it.
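
After _finalize_saved_model_layers runs, a revived layer no longer needs its own Call override: the deserialized call_and_return_conditional_losses function is attached through ReplacedCall and dispatched from Layer.Call. The equivalent manual wiring, as a hedged sketch (names hypothetical):

    // `revived` is a Layer restored from a SavedModel and `restored_fn`
    // is the deserialized Func<Tensors, Tensors>.
    revived.ReplacedCall = restored_fn;
    var outputs = revived.Apply(inputs, training: false);
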
diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs
index 4df6613f..bca84a86 100644
--- a/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs
+++ b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs
@@ -85,16 +85,16 @@ namespace Tensorflow.Keras.Saving.SavedModel
             return _config;
         }
 
-        protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
-        {
-            if (SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function)
-            {
-                return base.Call(inputs, state, training);
-            }
-            else
-            {
-                return (func as Function).Apply(inputs);
-            }
-        }
+        //protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
+        //{
+        //    if (SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function)
+        //    {
+        //        return base.Call(inputs, state, training);
+        //    }
+        //    else
+        //    {
+        //        return (func as Function).Apply(inputs);
+        //    }
+        //}
     }
 }
diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs
index 9d611efe..0ec5d1a8 100644
--- a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs
+++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs
@@ -223,7 +223,7 @@ namespace Tensorflow.Keras.Saving.SavedModel
             //base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }),
             //    functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" })
             base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers" }),
-                functions.Concat(new string[] { }))
+                functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }))
         {
 
         }
diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs
index cb230605..51962830 100644
--- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs
@@ -64,23 +64,19 @@ public class SequentialModelLoad
         var model = tf.keras.models.load_model(@"Assets/python_func_model");
         model.summary();
 
-        var x = tf.random.uniform((8, 784), -1, 1);
-        var y = model.Apply(x);
-        Console.WriteLine(y);
+        model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
 
-        //model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
-
-        //var data_loader = new MnistModelLoader();
-        //var num_epochs = 1;
-        //var batch_size = 8;
+        var data_loader = new MnistModelLoader();
+        var num_epochs = 1;
+        var batch_size = 8;
 
-        //var dataset = data_loader.LoadAsync(new ModelLoadSetting
-        //{
-        //    TrainDir = "mnist",
-        //    OneHot = false,
-        //    ValidationSize = 58000,
-        //}).Result;
+        var dataset = data_loader.LoadAsync(new ModelLoadSetting
+        {
+            TrainDir = "mnist",
+            OneHot = false,
+            ValidationSize = 55000,
+        }).Result;
 
-        //model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
+        model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
     }
 }