using MethodBoundaryAspect.Fody.Attributes;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Eager;
using Tensorflow.Functions;
using static Tensorflow.Binding;

namespace Tensorflow.Graphs
{
    /// <summary>
    /// Method-boundary aspect that traces the decorated method into a
    /// <see cref="ConcreteFunction"/> and then runs that function on the
    /// caller's original inputs.
    /// func_graph.py func_graph_from_py_func
    /// </summary>
    /// <remarks>
    /// On entry the tensor arguments are swapped for placeholders named "inputs";
    /// on exit those placeholders and the method's return value are handed to
    /// <c>ConcreteFunction.ToGraph</c>, the function is stored in a static cache,
    /// and the return value is replaced by the result of calling the function
    /// with the original inputs.
    /// </remarks>
    [AllowChangingInputArguments]
    public sealed class AutoGraphAttribute : OnMethodBoundaryAspect
    {
        // Name given to every placeholder created in OnEntry; OnExit uses it to
        // recognise which tensors are the traced function's inputs.
        const string InputPlaceholderName = "inputs";

        ConcreteFunction function;
        Tensors originalInputs;
        string func_name;

        // Traced functions keyed by func_name.
        static Dictionary<string, ConcreteFunction> functions = new Dictionary<string, ConcreteFunction>();

        /// <summary>
        /// Runs before the decorated method. Either short-circuits with a cached
        /// function, or opens a new <see cref="ConcreteFunction"/> and replaces the
        /// tensor arguments with placeholders so the method body is traced.
        /// </summary>
        /// <param name="args">Execution context supplied by the weaver.</param>
        public override void OnEntry(MethodExecutionArgs args)
        {
            // TODO: func_name can be cache in FullName + Args
            // NOTE(review): the key includes ops.uid_function(); if that yields a new
            // id on every call the lookup below never hits - confirm before relying
            // on the cache.
            func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{ops.uid_function()}";

            if (functions.TryGetValue(func_name, out var cached))
            {
                function = cached;
                if (args.Arguments.Length > 0 && args.Arguments[0] is Tensors tensor_inputs)
                    args.ReturnValue = ConvertReturnValue(function.FilteredCall(tensor_inputs));
                else
                    args.ReturnValue = ConvertReturnValue(function.FilteredCall(args.Arguments.Select(x => x as Tensor).ToArray()));

                // Skip the method body: the cached function already produced the result.
                args.FlowBehavior = FlowBehavior.Return;
                return;
            }

            // make function as an Operation by autograph
            // need to restore mode when exits
            function = new ConcreteFunction(func_name);
            function.Enter();

            // convert to Tensors
            if (args.Arguments.Length > 0 && args.Arguments[0] is Tensors inputs)
            {
                originalInputs = inputs;
                var new_inputs = inputs
                    .Select(x => tf.placeholder(x.dtype, shape: x.shape, name: InputPlaceholderName))
                    .ToArray();
                args.Arguments[0] = new Tensors(new_inputs);
            }
            else
            {
                originalInputs = new Tensors();
                // convert args to placeholder; non-EagerTensor arguments are left untouched
                for (var i = 0; i < args.Arguments.Length; i++)
                {
                    if (args.Arguments[i] is EagerTensor tensor)
                    {
                        originalInputs.Add(tensor);
                        args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.shape, name: InputPlaceholderName);
                    }
                }
            }
        }

        /// <summary>
        /// Runs after the decorated method. Builds the function graph from the
        /// input placeholders and the traced outputs, closes the function, caches
        /// it, and replaces the return value with the result of calling it on the
        /// original inputs captured in <see cref="OnEntry"/>.
        /// </summary>
        /// <param name="args">Execution context supplied by the weaver.</param>
        public override void OnExit(MethodExecutionArgs args)
        {
            if (args.ReturnValue is Tensors outputs)
            {
                Tensors inputs = null;
                outputs = mark_as_return(outputs);
                if (args.Arguments.Length > 0 && args.Arguments[0] is Tensors inputs1)
                    inputs = inputs1;
                else
                    // OfType drops non-tensor arguments instead of yielding nulls that
                    // would be dereferenced by the placeholder filter below.
                    inputs = args.Arguments.OfType<Tensor>().ToArray();

                inputs = inputs.Where(x => IsInputPlaceholder(x)).ToArray();

                function.ToGraph(inputs, outputs);
            }
            else if (args.ReturnValue is Tensor output)
            {
                var inputs = args.Arguments
                    .OfType<Tensor>()
                    .Where(x => IsInputPlaceholder(x))
                    .ToArray();
                var outputs2 = array_ops.identity(output);
                function.ToGraph(inputs, outputs2);
            }

            function.Exit();

            // cache function.
            function.ReturnType = args.ReturnValue.GetType();
            functions[func_name] = function;

            // run function
            args.ReturnValue = ConvertReturnValue(function.FilteredCall(originalInputs));
        }

        /// <summary>
        /// Tells whether <paramref name="tensor"/> is one of the placeholders that
        /// <see cref="OnEntry"/> created for the traced function's inputs.
        /// </summary>
        static bool IsInputPlaceholder(Tensor tensor)
            => tensor.op.OpType == "Placeholder"
                && tensor.op.name.StartsWith(InputPlaceholderName, StringComparison.Ordinal);

        /// <summary>
        /// Converts the function's result to the return type recorded for the
        /// decorated method: a single <c>Tensor</c> when that type is
        /// <c>Tensor</c>, otherwise the <c>Tensors</c> collection unchanged.
        /// </summary>
        object ConvertReturnValue(Tensors tensors)
        {
            if (function.ReturnType == typeof(Tensor))
                return (Tensor)tensors;
            else
                return tensors;
        }

        /// <summary>
        /// Acts like identity but marks the `Tensor` as a return value.
        /// </summary>
        /// <param name="tensors">Tensors to wrap; may be null.</param>
        /// <returns>
        /// A new <c>Tensors</c> holding <c>array_ops.identity</c> of each element,
        /// or null when <paramref name="tensors"/> is null.
        /// </returns>
        public Tensors mark_as_return(Tensors tensors)
        {
            if (tensors == null)
                return null;

            var result = new Tensors();
            foreach (var tensor in tensors)
                result.Add(array_ops.identity(tensor));
            return result;
        }
    }
}