From 99c748fa5a34b8f234af006c3f95952a17f0ce18 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 31 Dec 2018 11:21:23 -0600 Subject: [PATCH] Completed TEST(CAPI, ImportGraphDef) except graph.Dispose exception threw. --- src/TensorFlowNET.Core/Graphs/Graph.cs | 29 +++++- .../Graphs/TF_ImportGraphDefResults.cs | 17 ++++ src/TensorFlowNET.Core/Graphs/c_api.graph.cs | 36 +++++++- .../Operations/Operation.cs | 35 +++++-- .../Operations/TF_Operation.cs | 13 +++ src/TensorFlowNET.Core/c_api.cs | 2 +- test/TensorFlowNET.UnitTest/GraphTest.cs | 92 ++++++++++++++++--- 7 files changed, 194 insertions(+), 30 deletions(-) create mode 100644 src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs create mode 100644 src/TensorFlowNET.Core/Operations/TF_Operation.cs diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 7ea8085d..c5b18ae1 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -179,31 +179,50 @@ namespace Tensorflow TF_Output[] return_outputs = new TF_Output[num_return_outputs]; for (int i = 0; i < num_return_outputs; i++) { - return_outputs[i] = Marshal.PtrToStructure(return_output_handle + (Marshal.SizeOf() * i)); + var handle = return_output_handle + (Marshal.SizeOf() * i); + return_outputs[i] = Marshal.PtrToStructure(handle); } return return_outputs; } - public Operation[] ReturnOperations(IntPtr results) + public unsafe Operation[] ReturnOperations(IntPtr results) { - IntPtr return_oper_handle = IntPtr.Zero; + TF_Operation return_oper_handle = new TF_Operation(); int num_return_opers = 0; - c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_opers, ref return_oper_handle); + c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); Operation[] return_opers = new Operation[num_return_opers]; for (int i = 0; i < num_return_opers; i++) { - // return_opers[i] = Marshal.PtrToStructure(return_oper_handle + (Marshal.SizeOf() * i)); + var handle = return_oper_handle.node + Marshal.SizeOf() * i; + return_opers[i] = new Operation(*(IntPtr*)handle); } return return_opers; } + public Operation OperationByName(string operName) + { + return c_api.TF_GraphOperationByName(_handle, operName); + } + public Operation[] get_operations() { return _nodes_by_name.Values.Select(x => x).ToArray(); } + public GraphDef ToGraphDef() + { + var s = new Status(); + var buffer = new Buffer(); + c_api.TF_GraphToGraphDef(_handle, buffer, s); + s.Check(); + var def = GraphDef.Parser.ParseFrom(buffer); + buffer.Dispose(); + s.Dispose(); + return def; + } + public void Dispose() { c_api.TF_DeleteGraph(_handle); diff --git a/src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs b/src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs new file mode 100644 index 00000000..d857f56f --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs @@ -0,0 +1,17 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; + +namespace Tensorflow +{ + [StructLayout(LayoutKind.Sequential)] + public struct TF_ImportGraphDefResults + { + public IntPtr return_tensors; + public IntPtr return_nodes; + public IntPtr missing_unused_key_names; + public IntPtr missing_unused_key_indexes; + public IntPtr missing_unused_key_names_data; + } +} diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index 4c68f68b..e9e05583 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -22,6 +22,13 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TF_DeleteImportGraphDefOptions(IntPtr opts); + /// + /// Deletes a results object returned by TF_GraphImportGraphDefWithResults(). + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_DeleteImportGraphDefResults(IntPtr results); + [DllImport(TensorFlowLibName)] public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); @@ -91,9 +98,9 @@ namespace Tensorflow /// Write out a serialized representation of `graph` (as a GraphDef protocol /// message) to `output_graph_def` (allocated by TF_NewBuffer()). /// - /// - /// - /// + /// TF_Graph* + /// TF_Buffer* + /// TF_Status* [DllImport(TensorFlowLibName)] public static extern void TF_GraphToGraphDef(IntPtr graph, IntPtr output_graph_def, IntPtr status); @@ -110,6 +117,15 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, IntPtr status); + /// + /// Cause the imported graph to have a control dependency on `oper`. `oper` + /// should exist in the graph being imported into. + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_ImportGraphDefOptionsAddControlDependency(IntPtr opts, IntPtr oper); + /// /// Set any imported nodes with input `src_name:src_index` to have that input /// replaced with `dst`. `src_name` refers to a node in the graph to be imported, @@ -163,6 +179,18 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern int TF_ImportGraphDefOptionsNumReturnOutputs(IntPtr opts); + /// + /// Set any imported nodes with control input `src_name` to have that input + /// replaced with `dst`. `src_name` refers to a node in the graph to be imported, + /// `dst` references an operation already existing in the graph being imported + /// into. `src_name` is copied and has no lifetime requirements. + /// + /// TF_ImportGraphDefOptions* + /// const char* + /// TF_Operation* + [DllImport(TensorFlowLibName)] + public static extern void TF_ImportGraphDefOptionsRemapControlDependency(IntPtr opts, string src_name, IntPtr dst); + /// /// Set the prefix to be prepended to the names of nodes in `graph_def` that will /// be imported into `graph`. `prefix` is copied and has no lifetime @@ -182,7 +210,7 @@ namespace Tensorflow /// int* /// TF_Operation*** [DllImport(TensorFlowLibName)] - public static extern void TF_ImportGraphDefResultsReturnOperations(IntPtr results, ref int num_opers, ref IntPtr opers); + public static extern void TF_ImportGraphDefResultsReturnOperations(IntPtr results, ref int num_opers, ref TF_Operation opers); /// /// Fetches the return outputs requested via diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index a41974f2..d9b91c8d 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -18,13 +18,16 @@ namespace Tensorflow public string Name => c_api.StringPiece(c_api.TF_OperationName(_handle)); public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); + public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index)); public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status); + public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index)); public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index)); public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status); public int NumInputs => c_api.TF_OperationNumInputs(_handle); + public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) { @@ -42,21 +45,41 @@ namespace Tensorflow public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); - public Operation[] ControlInputs(int max_control_inputs) + public unsafe Operation[] GetControlInputs() { var control_inputs = new Operation[NumControlInputs]; - var control_input_handles = Marshal.AllocHGlobal(Marshal.SizeOf() * NumControlInputs); - c_api.TF_OperationGetControlInputs(_handle, control_input_handles, max_control_inputs); + + if(NumControlInputs > 0) + { + IntPtr control_input_handle = Marshal.AllocHGlobal(Marshal.SizeOf()); + c_api.TF_OperationGetControlInputs(_handle, control_input_handle, NumControlInputs); + for (int i = 0; i < NumControlInputs; i++) + { + var handle = control_input_handle + Marshal.SizeOf() * i; + control_inputs[i] = new Operation(*(IntPtr*)handle); + } + } + return control_inputs; } public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); - public Operation[] ControlOutputs(int max_control_outputs) + public unsafe Operation[] GetControlOutputs() { var control_outputs = new Operation[NumControlOutputs]; - var control_output_handles = Marshal.AllocHGlobal(Marshal.SizeOf() * NumControlOutputs); - c_api.TF_OperationGetControlInputs(_handle, control_output_handles, max_control_outputs); + + if(NumControlOutputs > 0) + { + IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf()); + c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs); + for (int i = 0; i < NumControlInputs; i++) + { + var handle = control_output_handle + Marshal.SizeOf() * i; + control_outputs[i] = new Operation(*(IntPtr*)handle); + } + } + return control_outputs; } diff --git a/src/TensorFlowNET.Core/Operations/TF_Operation.cs b/src/TensorFlowNET.Core/Operations/TF_Operation.cs new file mode 100644 index 00000000..6d43bdb2 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/TF_Operation.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; + +namespace Tensorflow +{ + [StructLayout(LayoutKind.Sequential)] + public struct TF_Operation + { + public IntPtr node; + } +} diff --git a/src/TensorFlowNET.Core/c_api.cs b/src/TensorFlowNET.Core/c_api.cs index 77cc01af..fd5952c8 100644 --- a/src/TensorFlowNET.Core/c_api.cs +++ b/src/TensorFlowNET.Core/c_api.cs @@ -16,7 +16,7 @@ namespace Tensorflow /// TF_XX** => ref IntPtr (TF_Operation** op) => (ref IntPtr op) /// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph) /// struct => struct (TF_Output output) => (TF_Output output) - /// struct* => struct (TF_Output* output) => (TF_Output[] output) + /// struct* => struct[] (TF_Output* output) => (TF_Output[] output) /// struct* => struct* for ref /// const char* => string /// int32_t => int diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index 515b503c..2e813ad8 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -228,9 +228,9 @@ namespace TensorFlowNET.UnitTest c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); EXPECT_EQ(TF_Code.TF_OK, s.Code); - Operation scalar = c_api.TF_GraphOperationByName(graph, "imported/scalar"); - Operation feed = c_api.TF_GraphOperationByName(graph, "imported/feed"); - Operation neg = c_api.TF_GraphOperationByName(graph, "imported/neg"); + Operation scalar = graph.OperationByName("imported/scalar"); + Operation feed = graph.OperationByName("imported/feed"); + Operation neg = graph.OperationByName("imported/neg"); // Test basic structure of the imported graph. EXPECT_EQ(0, scalar.NumInputs); @@ -243,19 +243,19 @@ namespace TensorFlowNET.UnitTest // Test that we can't see control edges involving the source and sink nodes. EXPECT_EQ(0, scalar.NumControlInputs); - EXPECT_EQ(0, scalar.ControlInputs(100).Length); + EXPECT_EQ(0, scalar.GetControlInputs().Length); EXPECT_EQ(0, scalar.NumControlOutputs); - EXPECT_EQ(0, scalar.ControlOutputs(100).Length); + EXPECT_EQ(0, scalar.GetControlOutputs().Length); EXPECT_EQ(0, feed.NumControlInputs); - EXPECT_EQ(0, feed.ControlInputs(100).Length); + EXPECT_EQ(0, feed.GetControlInputs().Length); EXPECT_EQ(0, feed.NumControlOutputs); - EXPECT_EQ(0, feed.ControlOutputs(100).Length); + EXPECT_EQ(0, feed.GetControlOutputs().Length); EXPECT_EQ(0, neg.NumControlInputs); - EXPECT_EQ(0, neg.ControlInputs(100).Length); + EXPECT_EQ(0, neg.GetControlInputs().Length); EXPECT_EQ(0, neg.NumControlOutputs); - EXPECT_EQ(0, neg.ControlOutputs(100).Length); + EXPECT_EQ(0, neg.GetControlOutputs().Length); // Import it again, with an input mapping, return outputs, and a return // operation, into the same graph. @@ -271,9 +271,9 @@ namespace TensorFlowNET.UnitTest var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s); EXPECT_EQ(TF_Code.TF_OK, s.Code); - Operation scalar2 = c_api.TF_GraphOperationByName(graph, "imported2/scalar"); - Operation feed2 = c_api.TF_GraphOperationByName(graph, "imported2/feed"); - Operation neg2 = c_api.TF_GraphOperationByName(graph, "imported2/neg"); + Operation scalar2 = graph.OperationByName("imported2/scalar"); + Operation feed2 = graph.OperationByName("imported2/feed"); + Operation neg2 = graph.OperationByName("imported2/neg"); // Check input mapping neg_input = neg.Input(0); @@ -289,8 +289,72 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(0, return_outputs[1].index); // Check return operation - var num_return_opers = graph.ReturnOperations(results); - ASSERT_EQ(1, num_return_opers); + var return_opers = graph.ReturnOperations(results); + ASSERT_EQ(1, return_opers.Length); + EXPECT_EQ(scalar2, return_opers[0]); // not remapped + c_api.TF_DeleteImportGraphDefResults(results); + + // Import again, with control dependencies, into the same graph. + c_api.TF_DeleteImportGraphDefOptions(opts); + opts = c_api.TF_NewImportGraphDefOptions(); + c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3"); + c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed); + c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); + c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); + EXPECT_EQ(TF_Code.TF_OK, s.Code); + + var scalar3 = graph.OperationByName("imported3/scalar"); + var feed3 = graph.OperationByName("imported3/feed"); + var neg3 = graph.OperationByName("imported3/neg"); + ASSERT_TRUE(scalar3 != IntPtr.Zero); + ASSERT_TRUE(feed3 != IntPtr.Zero); + ASSERT_TRUE(neg3 != IntPtr.Zero); + + // Check that newly-imported scalar and feed have control deps (neg3 will + // inherit them from input) + var control_inputs = scalar3.GetControlInputs(); + ASSERT_EQ(2, scalar3.NumControlInputs); + EXPECT_EQ(feed, control_inputs[0]); + EXPECT_EQ(feed2, control_inputs[1]); + + control_inputs = feed3.GetControlInputs(); + ASSERT_EQ(2, feed3.NumControlInputs); + EXPECT_EQ(feed, control_inputs[0]); + EXPECT_EQ(feed2, control_inputs[1]); + + // Export to a graph def so we can import a graph with control dependencies + graph_def.Dispose(); + graph_def = new Buffer(); + c_api.TF_GraphToGraphDef(graph, graph_def, s); + EXPECT_EQ(TF_Code.TF_OK, s.Code); + + // Import again, with remapped control dependency, into the same graph + c_api.TF_DeleteImportGraphDefOptions(opts); + opts = c_api.TF_NewImportGraphDefOptions(); + c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4"); + c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed); + c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + + var scalar4 = graph.OperationByName("imported4/imported3/scalar"); + var feed4 = graph.OperationByName("imported4/imported2/feed"); + + // Check that imported `imported3/scalar` has remapped control dep from + // original graph and imported control dep + control_inputs = scalar4.GetControlInputs(); + ASSERT_EQ(2, scalar4.NumControlInputs); + EXPECT_EQ(feed, control_inputs[0]); + EXPECT_EQ(feed4, control_inputs[1]); + + c_api.TF_DeleteImportGraphDefOptions(opts); + c_api.TF_DeleteBuffer(graph_def); + + // Can add nodes to the imported graph without trouble. + c_test_util.Add(feed, scalar, graph, s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + + //graph.Dispose(); + s.Dispose(); } } }