diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index 89828600..a3067182 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Linq; using Tensorflow.Framework.Models; using Tensorflow.Graphs; @@ -11,11 +12,34 @@ namespace Tensorflow.Functions /// public class ConcreteFunction : IDisposable { - public string Name => _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); IntPtr _handle; + FuncGraph func_graph; + + public string Name + { + get + { + if (func_graph != null) + return func_graph.FuncName; + + return _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); + } + } + public Tensor[] Outputs; + public Type ReturnType; public TensorSpec[] OutputStructure; + public ConcreteFunction(string name) + { + func_graph = new FuncGraph(name); + } + + public ConcreteFunction(FuncGraph graph, Dictionary attrs) + { + func_graph = graph; + } + public ConcreteFunction(Func func, TF_DataType dtype) { string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; @@ -28,8 +52,8 @@ namespace Tensorflow.Functions var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); _handle = graph.ToGraph(opers, - new Operation[] { input }, - new Operation[] { output }, + new[] { input }, + new[] { output }, null); } } @@ -48,8 +72,8 @@ namespace Tensorflow.Functions var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); _handle = graph.ToGraph(opers, - new Operation[] { input }, - new Operation[] { output.variant_tensor.op }, + new[] { input }, + new[] { output.variant_tensor }, null); } } @@ -72,12 +96,38 @@ namespace Tensorflow.Functions var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); _handle = graph.ToGraph(opers, - new Operation[] { input1, input2, input3 }, - new Operation[] { outputs.Item1.op, outputs.Item2.op }, + new[] { input1, input2, input3 }, + new[] { outputs.Item1, outputs.Item2 }, null); } } + public void ToGraph(Tensors inputs, Tensors outputs) + { + var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); + _handle = func_graph.ToGraph(opers, + inputs, + outputs, + null); + } + + public Tensors Invoke(Tensors inputs) + { + var forward_backward = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly()); + var (forward_function, args_with_tangents) = forward_backward.Forward(); + Tensors flat_outputs = null; + if (tf.Context.executing_eagerly()) + flat_outputs = forward_function.Call(args_with_tangents); + forward_backward.Record(flat_outputs); + return flat_outputs; + } + + ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) + { + var functions = new FirstOrderTapeGradientFunctions(func_graph, false); + return new ForwardBackwardCall(functions, args, tape_watching: true); + } + public void Dispose() { c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle); diff --git a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs new file mode 100644 index 00000000..f615f6a4 --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs @@ -0,0 +1,44 @@ +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Graphs; +using static Tensorflow.Binding; + +namespace Tensorflow.Functions +{ + public class EagerDefinedFunction + { + public int _num_outputs; + public string Name => _func_graph.FuncName; + + FuncGraph _func_graph; + public EagerDefinedFunction(string name, FuncGraph graph, + Tensors inputs, Tensors outputs, + Dictionary attrs) + { + _num_outputs = outputs.Length; + + var input_ops = inputs.Select(x => x.op).ToArray(); + var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op)) + .Select(x => x as Operation).ToArray(); + var output_names = new string[0]; + + _func_graph = new FuncGraph(graph, name, attrs); + _func_graph.ToGraph(operations, inputs, outputs, output_names); + } + + public Tensors Call(Tensors args) + { + var results = tf.Runner.TFE_Execute(tf.Context, + tf.Context.DeviceName, + _func_graph.FuncName, + args, + null, + _num_outputs); + + return results; + } + } +} diff --git a/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs new file mode 100644 index 00000000..3c099927 --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs @@ -0,0 +1,25 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Graphs; + +namespace Tensorflow.Functions +{ + public class FirstOrderTapeGradientFunctions : TapeGradientFunctions + { + public FirstOrderTapeGradientFunctions(FuncGraph func_graph, + bool need_gradients_for_jvps) : base(func_graph, + need_gradients_for_jvps) + { + + } + + public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args) + { + var outputs = _func_graph.Outputs; + (_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs) + = BuildFunctionsForOutputs(outputs, inference_args); + return _forward; + } + } +} diff --git a/src/TensorFlowNET.Core/Functions/ForwardBackwardCall.cs b/src/TensorFlowNET.Core/Functions/ForwardBackwardCall.cs new file mode 100644 index 00000000..cb4d6f1c --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/ForwardBackwardCall.cs @@ -0,0 +1,38 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Functions +{ + /// + /// Holds the state of a function call between execution and recording. + /// + public class ForwardBackwardCall + { + TapeGradientFunctions _functions; + Tensors _inference_args; + Tensors _input_tangents; + bool _tape_watching; + + public ForwardBackwardCall(TapeGradientFunctions functions, + Tensors inference_args, + bool tape_watching) + { + _functions = functions; + _inference_args = inference_args; + _tape_watching = tape_watching; + } + + public (EagerDefinedFunction, Tensors) Forward() + { + var forward_function = _functions.Forward(_inference_args); + return (forward_function, _inference_args); + } + + public void Record(Tensors flat_outputs) + { + if (_tape_watching && flat_outputs != null) + _functions.Record(flat_outputs, _inference_args); + } + } +} diff --git a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs new file mode 100644 index 00000000..4cd59c92 --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs @@ -0,0 +1,120 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Graphs; +using static Tensorflow.Binding; +using static Tensorflow.tensorflow; + +namespace Tensorflow.Functions +{ + /// + /// Caches forward and backward functions compatible with eager gradients. + /// + public abstract class TapeGradientFunctions + { + string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; + string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; + string _FORWARD_PREFIX = "__forward_"; + string _BACKWARD_PREFIX = "__backward_"; + string _INFERENCE_PREFIX = "__inference_"; + + protected FuncGraph _func_graph; + protected EagerDefinedFunction _forward; + protected FuncGraph _forward_graph; + protected List _forwardprop_output_indices; + protected int _num_forwardprop_outputs; + protected ConcreteFunction _backward; + + public TapeGradientFunctions(FuncGraph func_graph, + bool need_gradients_for_jvps) + { + _func_graph = func_graph; + } + + public EagerDefinedFunction Forward(Tensors inference_args) + { + return ForwardAndBackwardFunctions(inference_args); + } + + /// + /// Record the function call operation. + /// + /// + /// + public void Record(Tensors flat_outputs, Tensors inference_args) + { + var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs); + tf.Runner.RecordGradient(_forward.Name, flat_outputs, new object[0], inference_args, + getBackwardFunction: () => backward_function); + } + + (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors flat_outputs) + { + BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => + { + return new Tensor[0]; + + /*var gradients = ops.gradientFunctions[op_name](new EagerOperation + { + Name = op_name, + NumInputs = op_inputs.Length, + Inputs = op_inputs, + NumOutputs = op_outputs.Length, + Outputs = op_outputs, + SkipInputIndices = unneeded_gradients, + Attrs = attrs + }, output_grads); + + return gradients;*/ + }; + + return (_backward_function_wrapper, flat_outputs); + } + + protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List, int) + BuildFunctionsForOutputs(Tensors outputs, Tensors inference_args) + { + var trainable_outputs = new List(); + var trainable_indices = new List(); + foreach(var (index, output) in enumerate(outputs)) + { + if (gradients_util.IsTrainable(output)) + { + trainable_outputs.Add(output); + trainable_indices.Add(index); + } + } + + var gradients_wrt_outputs = new List(); + var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}"); + foreach (var output in trainable_outputs) + gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); + var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), + _func_graph.Inputs, + grad_ys: gradients_wrt_outputs.ToArray(), + src_graph: _func_graph); + + tf.Context.restore_mode(); + + var forward_function_name = $"{_FORWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}"; + var backward_function_attr = new Dictionary(); + backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; + backwards_graph.Inputs = gradients_wrt_outputs; + backwards_graph.Outputs = gradients_wrt_inputs; + + var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); + + var forward_function_attr = new Dictionary(); + forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; + var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, + _func_graph.Inputs, _func_graph.Outputs, forward_function_attr); + + return (forward_function, _func_graph, backward_function, null, 0); + } + + public virtual EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args) + { + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Functions/c_api.function.cs b/src/TensorFlowNET.Core/Functions/c_api.function.cs index bf93ae74..230d85ba 100644 --- a/src/TensorFlowNET.Core/Functions/c_api.function.cs +++ b/src/TensorFlowNET.Core/Functions/c_api.function.cs @@ -47,6 +47,9 @@ namespace Tensorflow string description, SafeStatusHandle status); + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_FunctionSetAttrValueProto(IntPtr func, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status); + [DllImport(TensorFlowLibName)] public static extern IntPtr TF_FunctionName(IntPtr func); diff --git a/src/TensorFlowNET.Core/Gradients/ITape.cs b/src/TensorFlowNET.Core/Gradients/ITape.cs index 69f102a7..279ad876 100644 --- a/src/TensorFlowNET.Core/Gradients/ITape.cs +++ b/src/TensorFlowNET.Core/Gradients/ITape.cs @@ -13,8 +13,6 @@ namespace Tensorflow.Gradients void RecordOperation(string op_type, Tensor[] input_tensors, TapeTensor[] output_tensors, - long[] input_tensor_id, - TF_DataType[] input_dtypes, Func backward_function_getter); void VariableAccessed(ResourceVariable variable); diff --git a/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs b/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs index c39ec73f..7b0e51f2 100644 --- a/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs +++ b/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using Tensorflow.Util; using static Tensorflow.tensorflow; using static Tensorflow.Binding; +using System.Linq; namespace Tensorflow.Gradients { @@ -14,18 +15,19 @@ namespace Tensorflow.Gradients public void RecordOperation(string op_type, Tensor[] input_tensors, TapeTensor[] output_tensors, - long[] input_tensor_id, - TF_DataType[] input_dtypes, Func backward_function_getter) { - if (!ShouldRecord(input_tensor_id, input_dtypes)) + var input_ids = input_tensors.Select(x => x.Id).ToArray(); + var input_dtypes = input_tensors.Select(x => x.dtype).ToArray(); + + if (!ShouldRecord(input_ids, input_dtypes)) { return; } long op_id = next_op_id_++; - var ids = new List(input_tensor_id.Length); - foreach (var i in input_tensor_id) + var ids = new List(input_ids.Length); + foreach (var i in input_ids) { tensor_usage_[i]++; ids.Add(i); diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index ce7c309b..1f401a7f 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -17,6 +17,7 @@ using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Graphs; using Tensorflow.Operations.ControlFlows; using static Tensorflow.Binding; @@ -39,7 +40,14 @@ namespace Tensorflow // If src_graph is a _FuncGraph (i.e. a function body), gather it and all // ancestor graphs. This is necessary for correctly handling captured values. + var func_graphs = new List(); var curr_graph = src_graph; + if (src_graph is FuncGraph func_graph) + { + func_graphs.append(func_graph); + curr_graph = func_graph.OuterGraph; + } + if (stop_gradients == null) stop_gradients = new Tensor[0]; @@ -84,7 +92,7 @@ namespace Tensorflow var to_ops = ys.Select(x => x.op).ToList(); var from_ops = xs.Select(x => x.op).ToList(); var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); - (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List(), xs); + (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs , xs); // Add the initial gradients for the ys. foreach (var (y, grad_y) in zip(ys, grad_ys)) @@ -258,11 +266,8 @@ namespace Tensorflow { var new_grad_ys = new List(); - for (int i = 0; i < grad_ys.Length; i++) + foreach(var (i, (y, grad_y)) in enumerate(zip(ys, grad_ys))) { - var grad_y = grad_ys[i]; - var y = ys[i]; - _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops); if (grad_y == null) @@ -272,8 +277,17 @@ namespace Tensorflow var shape = array_ops.shape(y); var constant = constant_op.constant(y.dtype == TF_DataType.TF_DOUBLE ? (object)1.0 : (object)1.0f, name: $"grad_ys_{i}"); var fill = gen_array_ops.fill(shape, constant); - new_grad_ys.Add(fill); + new_grad_ys.append(fill); + continue; } + + if (y.dtype.is_floating() || y.dtype.is_integer()) + { + + } + + // Create a grad_y tensor in the name scope of the gradient. + new_grad_ys.append(array_ops.identity(grad_y, name: $"grad_ys_{i}")); } return new_grad_ys.ToArray(); @@ -294,7 +308,11 @@ namespace Tensorflow /// /// /// - private static (Operation[], Dictionary, ControlFlowState) _PendingCount(List to_ops, List from_ops, bool colocate_gradients_with_ops, List func_graphs, Tensor[] xs) + private static (Operation[], Dictionary, ControlFlowState) _PendingCount(List to_ops, + List from_ops, + bool colocate_gradients_with_ops, + List func_graphs, + Tensor[] xs) { // Mark reachable ops from from_ops. var reached_ops = new List(); @@ -511,7 +529,7 @@ namespace Tensorflow /// /// /// - private static void _MarkReachedOps(List from_ops, List reached_ops, List func_graphs) + private static void _MarkReachedOps(List from_ops, List reached_ops, List func_graphs) { Queue queue = new Queue(from_ops); while (queue.Count > 0) @@ -538,7 +556,7 @@ namespace Tensorflow /// /// /// - private static Operation[] _Consumers(Tensor t, List func_graphs) + private static Operation[] _Consumers(Tensor t, List func_graphs) { return t.consumers(); } @@ -647,7 +665,7 @@ namespace Tensorflow } } - private static bool IsTrainable(Tensor tensor) + public static bool IsTrainable(Tensor tensor) { var dtype = tensor.dtype.as_base_dtype(); return new TF_DataType[] { dtypes.float16, dtypes.float32, dtypes.float64, diff --git a/src/TensorFlowNET.Core/Graphs/AutoGraph.cs b/src/TensorFlowNET.Core/Graphs/AutoGraph.cs index 68a7a1d0..bbec00ea 100644 --- a/src/TensorFlowNET.Core/Graphs/AutoGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/AutoGraph.cs @@ -18,10 +18,11 @@ namespace Tensorflow.Graphs var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); var func_handle = graph.ToGraph(opers, - new Operation[] { input }, - new Operation[] { output }, + new[] { input }, + new[] { output }, null); } + return (Tensor input) => { @@ -48,11 +49,11 @@ namespace Tensorflow.Graphs var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); var func_handle = graph.ToGraph(opers, - new Operation[] { input1, input2 }, - new Operation[] { output }, + new[] { input1, input2 }, + new[] { output }, null); } - + return (Tensor a, Tensor b) => { var result = tf.Runner.TFE_Execute(tf.Context, diff --git a/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs b/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs index 1914f61d..b63f70aa 100644 --- a/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs +++ b/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; using Tensorflow.Eager; +using Tensorflow.Functions; using static Tensorflow.Binding; namespace Tensorflow.Graphs @@ -10,10 +11,10 @@ namespace Tensorflow.Graphs [AllowChangingInputArguments] public sealed class AutoGraphAttribute : OnMethodBoundaryAspect { - FuncGraph graph; + ConcreteFunction function; Tensors originalInputs; string func_name; - static Dictionary> functions = new Dictionary>(); + static Dictionary functions = new Dictionary(); public override void OnEntry(MethodExecutionArgs args) { @@ -21,22 +22,24 @@ namespace Tensorflow.Graphs if (functions.ContainsKey(func_name)) { + function = functions[func_name]; if (args.Arguments[0] is Tensors tensor_inputs) - args.ReturnValue = functions[func_name](tensor_inputs.ToArray()); + args.ReturnValue = ConvertReturnValue(function.Invoke(tensor_inputs)); else - args.ReturnValue = functions[func_name](args.Arguments.Select(x => x as Tensor).ToArray()); + args.ReturnValue = ConvertReturnValue(function.Invoke(args.Arguments.Select(x => x as Tensor).ToArray())); args.FlowBehavior = FlowBehavior.Return; return; } // make function as an Operation by autograph - graph = new FuncGraph(func_name); + // need to restore mode when exits + function = new ConcreteFunction(func_name); // convert to Tensors if (args.Arguments[0] is Tensors inputs) { originalInputs = inputs; - var new_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.TensorShape)).ToArray(); + var new_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.TensorShape, name: "inputs")).ToArray(); args.Arguments[0] = new Tensors(new_inputs); } else @@ -48,7 +51,7 @@ namespace Tensorflow.Graphs if (args.Arguments[i] is EagerTensor tensor) { originalInputs[i] = tensor; - args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape); + args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape, name: "inputs"); } } } @@ -56,58 +59,30 @@ namespace Tensorflow.Graphs public override void OnExit(MethodExecutionArgs args) { - var returnValue = mark_as_return(args.ReturnValue as Tensors); - - var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); - if (args.ReturnValue is Tensors outputs) { if (args.Arguments[0] is Tensors inputs) - { - graph.ToGraph(opers, - inputs.Select(x => x.op).ToArray(), - outputs.Select(x => x.op).ToArray(), - null); - } + function.ToGraph(inputs, outputs); else - { - graph.ToGraph(opers, - args.Arguments.Select(x => (x as Tensor).op).ToArray(), - outputs.Select(x => x.op).ToArray(), - null); - } + function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), outputs); } else - { - graph.ToGraph(opers, - args.Arguments.Select(x => (x as Tensor).op).ToArray(), - new Operation[] { (args.ReturnValue as Tensor).op }, - null); - } - - graph.Dispose(); + function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), args.ReturnValue as Tensor); - Func function = (x) => - { - var result = tf.Runner.TFE_Execute(tf.Context, - tf.Context.DeviceName, - func_name, - x, - null, - 1); - - return result[0]; - }; // cache function. + function.ReturnType = args.ReturnValue.GetType(); functions[func_name] = function; // run function - args.ReturnValue = function(originalInputs); + args.ReturnValue = ConvertReturnValue(function.Invoke(originalInputs)); } - Tensor mark_as_return(Tensor tensor) + object ConvertReturnValue(Tensors tensors) { - return array_ops.identity(tensor); + if (function.ReturnType == typeof(Tensor)) + return (Tensor)tensors; + else + return tensors; } } } diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs index f89536e5..f48a69c0 100644 --- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs @@ -10,13 +10,18 @@ namespace Tensorflow.Graphs /// public class FuncGraph : Graph { - List inputs; - List outputs; Graph outer_graph; + public Graph OuterGraph => outer_graph; + string func_name; + + // _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); IntPtr func_handle; public string FuncName => func_name; + public Tensors Inputs { get; set; } + public Tensors Outputs { get; set; } + /// /// Construct a new FuncGraph. /// @@ -29,8 +34,17 @@ namespace Tensorflow.Graphs as_default(); } + public FuncGraph(IntPtr handle, string name) + { + outer_graph = ops.get_default_graph(); + func_name = name; + + tf.Context.graph_mode(); + as_default(); + } + public IntPtr ToGraph(Operation[] opers, - Operation[] inputs, Operation[] outputs, + Tensor[] inputs, Tensor[] outputs, string[] output_names) { using var status = new Status(); @@ -40,9 +54,9 @@ namespace Tensorflow.Graphs opers.Length, opers.Select(x => (IntPtr)x).ToArray(), inputs.Length, - inputs.Select(x => new TF_Output(x, 0)).ToArray(), + inputs.Select(x => new TF_Output(x.op, 0)).ToArray(), outputs.Length, - outputs.Select(x => new TF_Output(x, 0)).ToArray(), + outputs.Select(x => new TF_Output(x.op, 0)).ToArray(), output_names == null || output_names.Length == 0 ? null : output_names, IntPtr.Zero, null, @@ -57,13 +71,18 @@ namespace Tensorflow.Graphs func_name = c_api.StringPiece(c_api.TF_FunctionName(func_handle)); + Inputs = inputs; + // mark_as_return + Outputs = outputs.Select(x => array_ops.identity(x)).ToArray(); + + tf.Context.restore_mode(); + return func_handle; } protected override void DisposeManagedResources() { base.DisposeManagedResources(); - tf.Context.restore_mode(); } } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs index d79ca07c..bb09bf8e 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs @@ -69,7 +69,7 @@ namespace Tensorflow throw new ValueError($"Could not find operation \"{operName}\" inside graph \"{_graph_key}\"."); var defaultKey = tf.get_default_graph().graph_key; - if (graph_key != defaultKey) + if (tf.get_default_graph().GetType().Name == "Graph" && graph_key != defaultKey) { //Console.WriteLine($"Current graph is not default graph."); throw new RuntimeError($"Current graph is not default graph. Default Graph Key: {defaultKey}, Current Graph Key: {graph_key}"); diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index e6f6d346..69f92c30 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -218,6 +218,8 @@ namespace Tensorflow { case nameof(Int32): return x.List.I.Select(x => (T)Convert.ChangeType(x, typeof(T))).ToArray(); + case nameof(Int64): + return x.List.I.Select(x => (T)Convert.ChangeType(x, typeof(T))).ToArray(); default: return null; } diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 99be2efe..814c6fc3 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -235,6 +235,14 @@ namespace Tensorflow } var _op = tf.OpDefLib._apply_op_helper("Identity", name, new { input }); + + if (tf.Runner.MustRecordGradient()) + { + tf.Runner.RecordGradient("Identity", _op.inputs, new object[] + { + "T", _op.get_attr("T") + }, _op.outputs); + } return _op.output; } @@ -632,8 +640,8 @@ namespace Tensorflow public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy, int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("StridedSliceGrad", name, new + => tf.Context.RunInAutoMode2( + () => tf.OpDefLib._apply_op_helper("StridedSliceGrad", name, new { shape, begin, @@ -645,8 +653,8 @@ namespace Tensorflow ellipsis_mask, new_axis_mask, shrink_axis_mask - }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + }).output, + () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "StridedSliceGrad", name, null, shape, begin, end, strides, dy, @@ -654,8 +662,22 @@ namespace Tensorflow "end_mask", end_mask, "ellipsis_mask", ellipsis_mask, "new_axis_mask", new_axis_mask, - "shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), - shape, begin, end, strides, dy); + "shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), + (op) => + { + var attrs = new object[] + { + "T", op.get_attr("T"), + "Index", op.get_attr("Index"), + "begin_mask", op.get_attr("begin_mask"), + "end_mask", op.get_attr("end_mask"), + "ellipsis_mask", op.get_attr("ellipsis_mask"), + "new_axis_mask", op.get_attr("new_axis_mask"), + "shrink_axis_mask", op.get_attr("shrink_axis_mask") + }; + tf.Runner.RecordGradient("StridedSliceGrad", op.inputs, attrs, op.outputs); + }, + new Tensors(shape, begin, end, strides, dy)); /// /// Removes dimensions of size 1 from the shape of a tensor. diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index 2f91c0d3..3d707215 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -23,6 +23,8 @@ using Tensorflow.Gradients; namespace Tensorflow { + public delegate Tensor[] BackwardFunction(Tensor[] grads, long[] unneeded_gradients); + public partial class tensorflow : ITensorFlowObject { public TF_DataType byte8 = TF_DataType.TF_UINT8; @@ -37,8 +39,6 @@ namespace Tensorflow public TF_DataType chars = TF_DataType.TF_STRING; public TF_DataType @string = TF_DataType.TF_STRING; - public delegate Tensor[] BackwardFunction(Tensor[] grads, long[] unneeded_gradients); - public Status Status; public OpDefLibrary OpDefLib; public Context Context; diff --git a/src/TensorFlowNET.Keras/Engine/TensorFlowOpLayer.cs b/src/TensorFlowNET.Keras/Engine/TensorFlowOpLayer.cs index 84ba57e2..3da828de 100644 --- a/src/TensorFlowNET.Keras/Engine/TensorFlowOpLayer.cs +++ b/src/TensorFlowNET.Keras/Engine/TensorFlowOpLayer.cs @@ -56,7 +56,7 @@ namespace Tensorflow.Keras.Engine // Record the gradient because custom-made ops don't go through the // code-gen'd eager call path - var op_type = op.node_def.Name; + var op_type = op.node_def.Op; tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs); diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs index 687bcafe..b358c719 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs @@ -1,9 +1,9 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; -using static Tensorflow.KerasApi; using static Tensorflow.Binding; using System.Collections.Generic; using System; +using System.Linq; namespace Tensorflow.Keras.Layers { @@ -26,9 +26,20 @@ namespace Tensorflow.Keras.Layers var result = array_ops.reshape(inputs, shape.ToArray()); if (!tf.Context.executing_eagerly()) - // result = result.set_shape(compute_output_shape(inputs.shape)); - throw new NotImplementedException(""); + result.set_shape(compute_output_shape(inputs.shape)); return result; } + + TensorShape compute_output_shape(TensorShape input_shape) + { + if (input_shape.dims[0] == -1) + { + input_shape = input_shape.dims[0]; + var output_shape = input_shape.concatenate(args.TargetShape.dims); + return output_shape; + } + else + throw new NotImplementedException(""); + } } } diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs index 3d7aaaa2..e821d9e7 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs @@ -1,6 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; using Tensorflow; +using Tensorflow.Graphs; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.ManagedAPI @@ -36,7 +37,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI /// /// /// - // [AutoGraph] + [AutoGraph] Tensor Mul(Tensor a, Tensor b) { return a * b;