From e19e59b3ddd1401b0bdd384f73b4630ca0e693a3 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 26 Dec 2020 09:58:04 -0600 Subject: [PATCH] Add FunctionCallOptions. --- src/TensorFlowNET.Core/Contexts/Context.cs | 2 ++ .../Contexts/FunctionCallOptions.cs | 20 +++++++++++++++++++ .../Eager/EagerRunner.TFE_FastPathExecute.cs | 2 +- src/TensorFlowNET.Core/Eager/c_api.eager.cs | 2 +- .../Functions/ConcreteFunction.cs | 18 +++++++++++++++++ .../Functions/TapeGradientFunctions.cs | 16 +++++++-------- .../ops.gradient_function_mapping.cs | 4 ++-- src/TensorFlowNET.Core/Graphs/FuncGraph.cs | 8 +++++--- src/TensorFlowNET.Core/Graphs/Graph.cs | 8 ++++---- src/TensorFlowNET.Keras/Engine/Functional.cs | 2 +- 10 files changed, 62 insertions(+), 20 deletions(-) create mode 100644 src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index 24e24a01..226625b7 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -35,6 +35,7 @@ namespace Tensorflow.Contexts public string ScopeName { get; set; } = ""; bool initialized = false; ContextSwitchStack context_switches; + public FunctionCallOptions FunctionCallOptions { get; } public SafeContextHandle Handle { get; } @@ -44,6 +45,7 @@ namespace Tensorflow.Contexts status.Check(true); context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false); initialized = true; + FunctionCallOptions = new FunctionCallOptions(); } /// diff --git a/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs b/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs new file mode 100644 index 00000000..3fee5d92 --- /dev/null +++ b/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Google.Protobuf; +using Google.Protobuf.Collections; + +namespace Tensorflow.Contexts +{ + public class FunctionCallOptions + { + public string config_proto_serialized() + { + var config = new ConfigProto + { + AllowSoftPlacement = true, + }; + return config.ToByteString().ToStringUtf8(); + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs index 1903e5cd..aa56ede5 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs @@ -371,7 +371,7 @@ namespace Tensorflow.Eager switch (type) { case TF_AttrType.TF_ATTR_STRING: - c_api.TFE_OpSetAttrString(op, key, value.ToString(), (uint)value.ToString().Length); + c_api.TFE_OpSetAttrString(op, key, value.ToString(), (ulong)value.ToString().Length); break; case TF_AttrType.TF_ATTR_TYPE: c_api.TFE_OpSetAttrType(op, key, (TF_DataType)value); diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index 8120ed66..f707804e 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -241,7 +241,7 @@ namespace Tensorflow /// const void* /// size_t [DllImport(TensorFlowLibName)] - public static extern void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, uint length); + public static extern void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, ulong length); [DllImport(TensorFlowLibName)] public static extern void TFE_OpSetAttrTypeList(SafeOpHandle op, string attr_name, TF_DataType[] values, int num_values); diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index 90cb0494..4f701e2e 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -14,6 +14,7 @@ namespace Tensorflow.Functions { IntPtr _handle; FuncGraph func_graph; + public Tensor[] CapturedInputs => func_graph.external_captures(); public string Name { @@ -38,6 +39,8 @@ namespace Tensorflow.Functions public ConcreteFunction(FuncGraph graph, Dictionary attrs) { func_graph = graph; + + ToGraph(graph.Inputs, graph.Outputs); } public ConcreteFunction(Func func, TF_DataType dtype) @@ -124,6 +127,21 @@ namespace Tensorflow.Functions return flat_outputs; } + public Tensor[] CallFlat(Tensor[] args, Tensor[] captured_inputs) + { + var new_args = new List(); + new_args.AddRange(args); + new_args.AddRange(captured_inputs); + args = new_args.ToArray(); + + var attrs = new object[] + { + "executor_type", "", + "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() + }; + return tf.Runner.Execute(tf.Context, func_graph.FuncName, 1, args, attrs); + } + ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) { var functions = new FirstOrderTapeGradientFunctions(func_graph, false); diff --git a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs index 19ae5fbc..0a98d91d 100644 --- a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs +++ b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs @@ -48,7 +48,7 @@ namespace Tensorflow.Functions getBackwardFunction: () => backward_function); } - (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors flat_outputs) + (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) { BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => { @@ -61,10 +61,11 @@ namespace Tensorflow.Functions processed_args.add(arg); input_index += 1; } - return output_grads;// backward.Invoke(processed_args.ToArray()); + + return backward.CallFlat(processed_args.ToArray(), outputs); }; - return (_backward_function_wrapper, flat_outputs); + return (_backward_function_wrapper, outputs); } protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List, int) @@ -82,7 +83,7 @@ namespace Tensorflow.Functions } var gradients_wrt_outputs = new List(); - var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}"); + var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{ops.uid()}"); foreach (var output in trainable_outputs) gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), @@ -90,16 +91,15 @@ namespace Tensorflow.Functions grad_ys: gradients_wrt_outputs.ToArray(), src_graph: _func_graph); - tf.Context.restore_mode(); - - var forward_function_name = $"{_FORWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}"; + var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}"; var backward_function_attr = new Dictionary(); backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; + gradients_wrt_outputs.append(backwards_graph.internal_captures()); backwards_graph.Inputs = gradients_wrt_outputs; backwards_graph.Outputs = gradients_wrt_inputs; var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); - + var forward_function_attr = new Dictionary(); forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, diff --git a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs index 5a3f5835..6de42037 100644 --- a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs +++ b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs @@ -49,14 +49,14 @@ namespace Tensorflow RegisterGradientFunction(m.GetCustomAttribute().Name, (oper, out_grads) => { - tf.Logger.Debug($"Caculate Gradient: {m.Name}"); + tf.Logger.Debug($"Caculate Gradient: {oper.name} {m.Name}"); var results = g.InvokeMember(m.Name, BindingFlags.InvokeMethod, null, null, args: new object[] { oper, out_grads }) as Tensor[]; foreach (var result in results.Where(x => x != null)) - tf.Logger.Debug($"{result.TensorShape}"); + tf.Logger.Debug($"Gradient: {result.name} {result.TensorShape}"); return results; } ); diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs index 106c51fe..5eedbd8b 100644 --- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs @@ -26,7 +26,9 @@ namespace Tensorflow.Graphs public Tensors Outputs { get; set; } public Dictionary Attrs { get; set; } - Dictionary _captures = new Dictionary(); + // new Dictionary _captures = new Dictionary(); + // public new Tensor[] external_captures => _captures.Values.Select(x => x.Item1).ToArray(); + /// /// Construct a new FuncGraph. /// @@ -129,7 +131,7 @@ namespace Tensorflow.Graphs Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null) { Tensor placeholder = null; - if (!_captures.ContainsKey(tensor.Id)) + if (!_captures.Contains(tensor.Id)) { placeholder = _create_substitute_placeholder(tensor, name: name, @@ -139,7 +141,7 @@ namespace Tensorflow.Graphs } else { - placeholder = _captures[tensor.Id].Item1; + placeholder = (((Tensor, Tensor))_captures[tensor.Id]).Item2; } BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index ce66966c..5ea2cc07 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -557,16 +557,16 @@ namespace Tensorflow public Tensor[] external_captures() { - Tensor[] captures = new Tensor[this._captures.Count]; - ICollection inner = this._captures.Keys; // c[0] + Tensor[] captures = new Tensor[_captures.Count]; + ICollection inner = _captures.Keys; // c[0] inner.CopyTo(captures, 0); return captures; } public Tensor[] internal_captures() { - Tensor[] captures = new Tensor[this._captures.Count]; - ICollection inner = this._captures.Values; // c[1] + Tensor[] captures = new Tensor[_captures.Count]; + ICollection inner = _captures.Values; // c[1] inner.CopyTo(captures, 0); return captures; } diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index b2b109ba..a02f46d2 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -340,7 +340,7 @@ namespace Tensorflow.Keras.Engine tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name}"); var outputs = node.Layer.Apply(layer_inputs, is_training: training); foreach (var output in outputs.Where(x => x != null)) - tf.Logger.Debug($"{output.TensorShape}"); + tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}"); // Update tensor_dict for next input foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) tensor_dict[x_id] = new Queue(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y));