From 7910b79258013e46888d9c4448521e6794ac8406 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 30 Dec 2018 10:55:04 -0600 Subject: [PATCH] OutputConsumers --- src/TensorFlowNET.Core/Graphs/Graph.cs | 2 +- .../Operations/Operation.cs | 40 +++++++++++++--- src/TensorFlowNET.Core/Operations/TF_Input.cs | 6 +++ .../Operations/c_api.ops.cs | 27 +++++++++++ .../Sessions/_FetchHandler.cs | 2 +- src/TensorFlowNET.Core/c_api.cs | 1 + test/TensorFlowNET.UnitTest/GraphTest.cs | 48 ++++++++++++++++--- 7 files changed, 112 insertions(+), 14 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 1fbc3789..69408446 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -126,7 +126,7 @@ namespace Tensorflow } else if (tensor_or_op is Operation) { - return !_unfetchable_ops.Contains((tensor_or_op as Operation).name); + return !_unfetchable_ops.Contains((tensor_or_op as Operation).Name); } return false; diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 550925e6..02d29e08 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -15,14 +15,31 @@ namespace Tensorflow private Status status = new Status(); - public string name => c_api.StringPiece(c_api.TF_OperationName(_handle)); - public string optype => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); - public string device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); + public string Name => c_api.StringPiece(c_api.TF_OperationName(_handle)); + public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); + public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); - public TF_DataType OutputType => c_api.TF_OperationOutputType(new TF_Output(_handle, 0)); - public int OutputListLength => c_api.TF_OperationOutputListLength(_handle, "output", status); + public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index)); + public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status); + public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index)); + public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index)); + public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status); public int NumInputs => c_api.TF_OperationNumInputs(_handle); - public int NumConsumers => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0)); + public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); + public TF_Input[] OutputConsumers(int index, int max_consumers) + { + IntPtr handle = IntPtr.Zero; + int size = Marshal.SizeOf(); + int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), ref handle, max_consumers); + var consumers = new TF_Input[num]; + for(int i = 0; i < num; i++) + { + consumers[0] = Marshal.PtrToStructure(handle + i * size); + } + + return consumers; + } + public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); @@ -92,5 +109,16 @@ namespace Tensorflow { return op._handle; } + + public override bool Equals(object obj) + { + switch (obj) + { + case IntPtr val: + return val == _handle; + } + + return base.Equals(obj); + } } } diff --git a/src/TensorFlowNET.Core/Operations/TF_Input.cs b/src/TensorFlowNET.Core/Operations/TF_Input.cs index 4adcd040..0d49d0e5 100644 --- a/src/TensorFlowNET.Core/Operations/TF_Input.cs +++ b/src/TensorFlowNET.Core/Operations/TF_Input.cs @@ -8,6 +8,12 @@ namespace Tensorflow [StructLayout(LayoutKind.Sequential)] public struct TF_Input { + public TF_Input(IntPtr oper, int index) + { + this.oper = oper; + this.index = index; + } + public IntPtr oper; public int index; } diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index 7bd829f1..02839147 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -49,6 +49,22 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, IntPtr status); + /// + /// TF_Output producer = TF_OperationInput(consumer); + /// There is an edge from producer.oper's output (given by + /// producer.index) to consumer.oper's input (given by consumer.index). + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern TF_Output TF_OperationInput(TF_Input oper_in); + + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationInputListLength(IntPtr oper, string arg_name, IntPtr status); + + [DllImport(TensorFlowLibName)] + public static extern TF_DataType TF_OperationInputType(TF_Input oper_in); + [DllImport(TensorFlowLibName)] public static extern IntPtr TF_OperationName(IntPtr oper); @@ -87,6 +103,17 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern int TF_OperationOutputNumConsumers(TF_Output oper_out); + /// + /// Get list of all current consumers of a specific output of an + /// operation. + /// + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationOutputConsumers(TF_Output oper_out, ref IntPtr consumers, int max_consumers); + [DllImport(TensorFlowLibName)] public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index 347a1293..44f8f261 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -43,7 +43,7 @@ namespace Tensorflow { if (!graph.is_fetchable(op)) { - throw new Exception($"Operation {op.name} has been marked as not fetchable."); + throw new Exception($"Operation {op.Name} has been marked as not fetchable."); } } diff --git a/src/TensorFlowNET.Core/c_api.cs b/src/TensorFlowNET.Core/c_api.cs index 70d933a3..b6de4639 100644 --- a/src/TensorFlowNET.Core/c_api.cs +++ b/src/TensorFlowNET.Core/c_api.cs @@ -17,6 +17,7 @@ namespace Tensorflow /// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph) /// struct => struct (TF_Output output) => (TF_Output output) /// struct* => struct (TF_Output* output) => (TF_Output[] output) + /// struct* => ref IntPtr (TF_Input* consumers) => (ref IntPtr handle), if output is struct[] /// const char* => string /// int32_t => int /// int64_t* => long[] diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index 9dab6da7..63b25f59 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -21,14 +21,14 @@ namespace TensorFlowNET.UnitTest // Make a placeholder operation. var feed = c_test_util.Placeholder(graph, s); - Assert.AreEqual("feed", feed.name); - Assert.AreEqual("Placeholder", feed.optype); - Assert.AreEqual("", feed.device); + Assert.AreEqual("feed", feed.Name); + Assert.AreEqual("Placeholder", feed.OpType); + Assert.AreEqual("", feed.Device); Assert.AreEqual(1, feed.NumOutputs); - Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType); - Assert.AreEqual(1, feed.OutputListLength); + Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType(0)); + Assert.AreEqual(1, feed.OutputListLength("output")); Assert.AreEqual(0, feed.NumInputs); - Assert.AreEqual(0, feed.NumConsumers); + Assert.AreEqual(0, feed.OutputNumConsumers(0)); Assert.AreEqual(0, feed.NumControlInputs); Assert.AreEqual(0, feed.NumControlOutputs); @@ -44,9 +44,45 @@ namespace TensorFlowNET.UnitTest // Make a constant oper with the scalar "3". var three = c_test_util.ScalarConst(3, graph, s); + Assert.AreEqual(TF_Code.TF_OK, s.Code); // Add oper. var add = c_test_util.Add(feed, three, graph, s); + Assert.AreEqual(TF_Code.TF_OK, s.Code); + + // Test TF_Operation*() query functions. + Assert.AreEqual("add", add.Name); + Assert.AreEqual("AddN", add.OpType); + Assert.AreEqual("", add.Device); + Assert.AreEqual(1, add.NumOutputs); + Assert.AreEqual(TF_DataType.TF_INT32, add.OutputType(0)); + Assert.AreEqual(1, add.OutputListLength("sum")); + Assert.AreEqual(TF_Code.TF_OK, s.Code); + Assert.AreEqual(2, add.InputListLength("inputs")); + Assert.AreEqual(TF_Code.TF_OK, s.Code); + Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(0)); + Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(1)); + var add_in_0 = add.Input(0); + Assert.AreEqual(feed, add_in_0.oper); + Assert.AreEqual(0, add_in_0.index); + var add_in_1 = add.Input(1); + Assert.AreEqual(three, add_in_1.oper); + Assert.AreEqual(0, add_in_1.index); + Assert.AreEqual(0, add.OutputNumConsumers(0)); + Assert.AreEqual(0, add.NumControlInputs); + Assert.AreEqual(0, add.NumControlOutputs); + + Assert.IsTrue(c_test_util.GetAttrValue(add, "T", ref attr_value, s)); + Assert.AreEqual(DataType.DtInt32, attr_value.Type); + Assert.IsTrue(c_test_util.GetAttrValue(add, "N", ref attr_value, s)); + Assert.AreEqual(2, attr_value.I); + + // Placeholder oper now has a consumer. + Assert.AreEqual(1, feed.OutputNumConsumers(0)); + TF_Input[] feed_port = feed.OutputConsumers(0, 1); + Assert.AreEqual(1, feed_port.Length); + Assert.AreEqual(add, feed_port[0].oper); + Assert.AreEqual(0, feed_port[0].index); } } }