diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs new file mode 100644 index 00000000..707195b3 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs @@ -0,0 +1,21 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; + +namespace Tensorflow +{ + public partial class Graph + { + public Buffer ToGraphDef(Status s) + { + var buffer = new Buffer(); + c_api.TF_GraphToGraphDef(_handle, buffer, s); + s.Check(); + // var def = GraphDef.Parser.ParseFrom(buffer); + // buffer.Dispose(); + + return buffer; + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs new file mode 100644 index 00000000..b22bad4b --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs @@ -0,0 +1,28 @@ +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; + +namespace Tensorflow +{ + public partial class Graph + { + public unsafe TF_Output[] ImportGraphDefWithReturnOutputs(Buffer graph_def, ImportGraphDefOptions opts, Status s) + { + var num_return_outputs = opts.NumReturnOutputs; + var return_outputs = new TF_Output[num_return_outputs]; + TF_Output* return_output_handle = (TF_Output*)Marshal.AllocHGlobal(Marshal.SizeOf() * 2); + + c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); + for (int i = 0; i < num_return_outputs; i++) + { + var handle = return_output_handle + i * Marshal.SizeOf(); + + return_outputs[i] = new TF_Output((*handle).oper, (*handle).index); + } + + return return_outputs; + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index c5b18ae1..b6f2b8fc 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -13,7 +13,7 @@ namespace Tensorflow /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. /// https://www.tensorflow.org/guide/graphs /// - public class Graph : IDisposable + public partial class Graph : IDisposable { private IntPtr _handle; private Dictionary _nodes_by_id; @@ -211,18 +211,6 @@ namespace Tensorflow return _nodes_by_name.Values.Select(x => x).ToArray(); } - public GraphDef ToGraphDef() - { - var s = new Status(); - var buffer = new Buffer(); - c_api.TF_GraphToGraphDef(_handle, buffer, s); - s.Check(); - var def = GraphDef.Parser.ParseFrom(buffer); - buffer.Dispose(); - s.Dispose(); - return def; - } - public void Dispose() { c_api.TF_DeleteGraph(_handle); diff --git a/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs b/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs new file mode 100644 index 00000000..1e46885a --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs @@ -0,0 +1,35 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class ImportGraphDefOptions : IDisposable + { + private IntPtr _handle; + public int NumReturnOutputs => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); + + public ImportGraphDefOptions() + { + _handle = c_api.TF_NewImportGraphDefOptions(); + } + + public ImportGraphDefOptions(IntPtr handle) + { + _handle = handle; + } + + public void AddReturnOutput(string name, int index) + { + c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); + } + + public void Dispose() + { + c_api.TF_DeleteImportGraphDefOptions(_handle); + } + + public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle; + public static implicit operator ImportGraphDefOptions(IntPtr handle) => new ImportGraphDefOptions(handle); + } +} diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index e9e05583..b2988a22 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -45,6 +45,24 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); + /// + /// Import the graph serialized in `graph_def` into `graph`. + /// Convenience function for when only return outputs are needed. + /// + /// `num_return_outputs` must be the number of return outputs added (i.e. the + /// result of TF_ImportGraphDefOptionsNumReturnOutputs()). If + /// `num_return_outputs` is non-zero, `return_outputs` must be of length + /// `num_return_outputs`. Otherwise it can be null. + /// + /// TF_Graph* graph + /// const TF_Buffer* + /// const TF_ImportGraphDefOptions* + /// TF_Output* + /// int + /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, IntPtr options, TF_Output* return_outputs, int num_return_outputs, IntPtr status); + /// /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and /// a bad status on error. Otherwise, returns a populated diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index edc8f5bd..00293c2f 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -357,10 +357,55 @@ namespace TensorFlowNET.UnitTest s.Dispose(); } + /// + /// Port from c_api_test.cc + /// `TEST(CAPI, ImportGraphDef_WithReturnOutputs)` + /// [TestMethod] public void c_api_ImportGraphDef_WithReturnOutputs() { + var s = new Status(); + var graph = new Graph(); + + // Create a graph with two nodes: x and 3 + var feed = c_test_util.Placeholder(graph, s); + EXPECT_EQ(feed, graph.OperationByName("feed")); + var scalar = c_test_util.ScalarConst(3, graph, s); + EXPECT_EQ(scalar, graph.OperationByName("scalar")); + var neg = c_test_util.Neg(scalar, graph, s); + EXPECT_EQ(neg, graph.OperationByName("neg")); + + // Export to a GraphDef. + var graph_def = graph.ToGraphDef(s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + + // Import it in a fresh graph with return outputs. + graph.Dispose(); + graph = new Graph(); + var opts = new ImportGraphDefOptions(); + opts.AddReturnOutput("feed", 0); + opts.AddReturnOutput("scalar", 0); + EXPECT_EQ(2, opts.NumReturnOutputs); + var return_outputs = graph.ImportGraphDefWithReturnOutputs(graph_def, opts, s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + scalar = graph.OperationByName("scalar"); + feed = graph.OperationByName("feed"); + neg = graph.OperationByName("neg"); + ASSERT_TRUE(scalar != IntPtr.Zero); + ASSERT_TRUE(feed != IntPtr.Zero); + ASSERT_TRUE(neg != IntPtr.Zero); + + // Check return outputs + EXPECT_EQ(feed, return_outputs[0].oper); + EXPECT_EQ(0, return_outputs[0].index); + EXPECT_EQ(scalar, return_outputs[1].oper); + EXPECT_EQ(0, return_outputs[1].index); + + opts.Dispose(); + graph_def.Dispose(); + graph.Dispose(); + s.Dispose(); } } }