using System;
using System.Diagnostics;
using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow.Graphs
{
    /// <summary>
    /// Converts managed tensor functions into TensorFlow graph functions.
    /// </summary>
    /// <remarks>
    /// Each <c>to_graph</c> overload traces the supplied delegate once on placeholder inputs inside a new
    /// <see cref="FuncGraph"/>. The graph is named after the delegate's method plus an
    /// <c>ops.uid_function()</c> suffix. The traced operations are registered via <c>FuncGraph.ToGraph</c>.
    /// The returned delegate either executes that registered function (eager mode) or invokes the
    /// original delegate again (graph mode).
    /// </remarks>
    public class AutoGraph
    {
        /// <summary>
        /// Builds a graph function from a unary tensor function.
        /// </summary>
        /// <param name="func">The function to trace. Must not be null.</param>
        /// <param name="dtype">Data type of the placeholder used as the traced input.</param>
        /// <returns>
        /// A delegate with two behaviors. When <c>tf.executing_eagerly()</c> is true, it runs the registered
        /// graph function through <c>TFE_Execute</c> and returns its first output. Otherwise it opens a
        /// session on the argument's graph and calls <paramref name="func"/> directly.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
        public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func, TF_DataType dtype = TF_DataType.TF_INT32)
        {
            if (func == null)
            {
                throw new ArgumentNullException(nameof(func));
            }

            string func_name = $"{func.Method.Name}_{ops.uid_function()}";
            TraceToFunction(func_name, new[] { dtype }, inputs => func(inputs[0]));

            return (Tensor x) =>
            {
                if (tf.executing_eagerly())
                {
                    return ExecuteFunction(func_name, x);
                }

                // NOTE(review): the session object itself is never used. It is kept so the session is still
                // created and disposed around the call. Confirm whether this is actually required.
                using (tf.Session(x.graph))
                {
                    return func(x);
                }
            };
        }

        /// <summary>
        /// Builds a graph function from a binary tensor function.
        /// </summary>
        /// <param name="func">The function to trace. Must not be null.</param>
        /// <param name="dtypes">
        /// Data types of the two traced placeholder inputs, in order. Missing entries (or a null array)
        /// default to <c>tf.int32</c>. Entries beyond the second are ignored.
        /// </param>
        /// <returns>
        /// A delegate with two behaviors. When <c>tf.executing_eagerly()</c> is true, it runs the registered
        /// graph function through <c>TFE_Execute</c> and returns its first output. Otherwise it opens a
        /// session on the first argument's graph and calls <paramref name="func"/> directly.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
        public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func, params TF_DataType[] dtypes)
        {
            if (func == null)
            {
                throw new ArgumentNullException(nameof(func));
            }

            // An explicit null for a params array would otherwise fail on dtypes.Length.
            dtypes = dtypes ?? Array.Empty<TF_DataType>();

            string func_name = $"{func.Method.Name}_{ops.uid_function()}";
            var inputTypes = new[]
            {
                dtypes.Length >= 1 ? dtypes[0] : tf.int32,
                dtypes.Length >= 2 ? dtypes[1] : tf.int32,
            };
            TraceToFunction(func_name, inputTypes, inputs => func(inputs[0], inputs[1]));

            return (Tensor a, Tensor b) =>
            {
                if (tf.executing_eagerly())
                {
                    return ExecuteFunction(func_name, a, b);
                }

                // NOTE(review): the session object itself is never used; see the unary overload.
                using (tf.Session(a.graph))
                {
                    Debug.Assert(a.graph == b.graph);
                    return func(a, b);
                }
            };
        }

        /// <summary>
        /// Traces <paramref name="body"/> on fresh placeholders inside a new <see cref="FuncGraph"/> named
        /// <paramref name="funcName"/>, then registers the resulting operations with <c>ToGraph</c>.
        /// </summary>
        /// <param name="funcName">Name given to the new <see cref="FuncGraph"/>.</param>
        /// <param name="inputTypes">One placeholder is created per entry, in order.</param>
        /// <param name="body">Receives the placeholders and returns the single traced output.</param>
        private static void TraceToFunction(string funcName, TF_DataType[] inputTypes, Func<Tensor[], Tensor> body)
        {
            var graph = new FuncGraph(funcName);
            graph.as_default();
            try
            {
                // ToArray forces placeholder creation, in declaration order, before the body is traced.
                var inputs = inputTypes.Select(dtype => tf.placeholder(dtype)).ToArray();
                var output = body(inputs);

                // A plain `as` cast would turn any non-Operation node into a null entry passed to ToGraph.
                var opers = graph._nodes_by_name.Values.OfType<Operation>().ToArray();
                graph.ToGraph(opers, inputs, new[] { output }, null);
            }
            finally
            {
                // Always leave the graph entered by as_default(), even when tracing or registration throws.
                graph.Exit();
            }
        }

        /// <summary>
        /// Executes the registered function <paramref name="funcName"/> through <c>tf.Runner.TFE_Execute</c>
        /// on <c>tf.Context.DeviceName</c>, requesting one output, and returns that output.
        /// </summary>
        private static Tensor ExecuteFunction(string funcName, params Tensor[] inputs)
        {
            var result = tf.Runner.TFE_Execute(tf.Context, tf.Context.DeviceName, funcName, inputs, null, 1);
            return result[0];
        }
    }
}