using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using Tensorflow.Functions;
using static Tensorflow.Binding;

namespace Tensorflow.Graphs
{
    /// <summary>
    /// Graph representing a function body.
    /// </summary>
    public class FuncGraph : Graph
    {
        // Default graph at construction time; the converted function is copied into it by ToGraph.
        readonly Graph outer_graph;

        // Name requested for the function when this graph is converted.
        readonly string func_name;

        // Native function handle produced by ToGraph; IntPtr.Zero until then.
        IntPtr func_handle;

        /// <summary>
        /// Name of the function. Once <see cref="ToGraph"/> has produced a native handle this is the
        /// name reported by <c>TF_FunctionName</c>; before that it is the name given to the constructor,
        /// so a null handle is never passed to native code.
        /// </summary>
        public string FuncName => func_handle == IntPtr.Zero
            ? func_name
            : c_api.StringPiece(c_api.TF_FunctionName(func_handle));

        /// <summary>
        /// Construct a new FuncGraph.
        /// </summary>
        /// <param name="name">Name used for the function created by <see cref="ToGraph"/>.</param>
        public FuncGraph(string name) : base()
        {
            outer_graph = ops.get_default_graph();
            func_name = name;
        }

        /// <summary>
        /// Converts this graph into a function via <c>TF_GraphToFunction</c>, copies that function into
        /// the outer graph via <c>TF_GraphCopyFunction</c>, and adds it to the eager context via
        /// <c>TFE_ContextAddFunction</c>.
        /// </summary>
        /// <param name="opers">Operations handed to <c>TF_GraphToFunction</c> as the function body.</param>
        /// <param name="inputs">Operations whose output 0 becomes a function input.</param>
        /// <param name="outputs">Operations whose output 0 becomes a function output.</param>
        /// <param name="output_names">Names for the outputs; null or empty means no names are passed.</param>
        /// <returns>The native handle of the created function.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="opers"/>, <paramref name="inputs"/> or <paramref name="outputs"/> is null.
        /// </exception>
        public IntPtr ToGraph(Operation[] opers, Operation[] inputs, Operation[] outputs, string[] output_names)
        {
            if (opers is null)
                throw new ArgumentNullException(nameof(opers));
            if (inputs is null)
                throw new ArgumentNullException(nameof(inputs));
            if (outputs is null)
                throw new ArgumentNullException(nameof(outputs));

            var oper_handles = opers.Select(x => (IntPtr)x).ToArray();
            // Every input/output is taken from output index 0 of its operation.
            var input_endpoints = inputs.Select(x => new TF_Output(x, 0)).ToArray();
            var output_endpoints = outputs.Select(x => new TF_Output(x, 0)).ToArray();
            // An empty name list is passed as null rather than as a zero-length array.
            var names = output_names == null || output_names.Length == 0 ? null : output_names;

            using var status = new Status();

            func_handle = c_api.TF_GraphToFunction(_handle,
                func_name,
                false, // TF C API: append_hash_to_fn_name
                oper_handles.Length,
                oper_handles,
                input_endpoints.Length,
                input_endpoints,
                output_endpoints.Length,
                output_endpoints,
                names,
                IntPtr.Zero, // TF C API: opts
                null,        // TF C API: description
                status.Handle);
            // NOTE(review): Check(true) presumably throws when the status reports an error - confirm.
            status.Check(true);

            // IntPtr.Zero: no gradient function is supplied alongside the copied function.
            c_api.TF_GraphCopyFunction(outer_graph, func_handle, IntPtr.Zero, status.Handle);
            status.Check(true);

            c_api.TFE_ContextAddFunction(tf.Context.Handle, func_handle, status.Handle);
            status.Check(true);

            return func_handle;
        }
    }
}