@@ -126,7 +126,7 @@ namespace Tensorflow | |||||
} | } | ||||
else if (tensor_or_op is Operation) | else if (tensor_or_op is Operation) | ||||
{ | { | ||||
return !_unfetchable_ops.Contains((tensor_or_op as Operation).name); | |||||
return !_unfetchable_ops.Contains((tensor_or_op as Operation).Name); | |||||
} | } | ||||
return false; | return false; | ||||
@@ -15,14 +15,31 @@ namespace Tensorflow | |||||
private Status status = new Status(); | private Status status = new Status(); | ||||
public string name => c_api.StringPiece(c_api.TF_OperationName(_handle)); | |||||
public string optype => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | |||||
public string device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||||
public string Name => c_api.StringPiece(c_api.TF_OperationName(_handle)); | |||||
public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | |||||
public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||||
public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | ||||
public TF_DataType OutputType => c_api.TF_OperationOutputType(new TF_Output(_handle, 0)); | |||||
public int OutputListLength => c_api.TF_OperationOutputListLength(_handle, "output", status); | |||||
public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index)); | |||||
public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status); | |||||
public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index)); | |||||
public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index)); | |||||
public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status); | |||||
public int NumInputs => c_api.TF_OperationNumInputs(_handle); | public int NumInputs => c_api.TF_OperationNumInputs(_handle); | ||||
public int NumConsumers => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0)); | |||||
public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); | |||||
public TF_Input[] OutputConsumers(int index, int max_consumers) | |||||
{ | |||||
IntPtr handle = IntPtr.Zero; | |||||
int size = Marshal.SizeOf<TF_Input>(); | |||||
int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), ref handle, max_consumers); | |||||
var consumers = new TF_Input[num]; | |||||
for(int i = 0; i < num; i++) | |||||
{ | |||||
consumers[0] = Marshal.PtrToStructure<TF_Input>(handle + i * size); | |||||
} | |||||
return consumers; | |||||
} | |||||
public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); | public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); | ||||
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | ||||
@@ -92,5 +109,16 @@ namespace Tensorflow | |||||
{ | { | ||||
return op._handle; | return op._handle; | ||||
} | } | ||||
public override bool Equals(object obj) | |||||
{ | |||||
switch (obj) | |||||
{ | |||||
case IntPtr val: | |||||
return val == _handle; | |||||
} | |||||
return base.Equals(obj); | |||||
} | |||||
} | } | ||||
} | } |
@@ -8,6 +8,12 @@ namespace Tensorflow | |||||
[StructLayout(LayoutKind.Sequential)] | [StructLayout(LayoutKind.Sequential)] | ||||
public struct TF_Input | public struct TF_Input | ||||
{ | { | ||||
public TF_Input(IntPtr oper, int index) | |||||
{ | |||||
this.oper = oper; | |||||
this.index = index; | |||||
} | |||||
public IntPtr oper; | public IntPtr oper; | ||||
public int index; | public int index; | ||||
} | } | ||||
@@ -49,6 +49,22 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, IntPtr status); | public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, IntPtr status); | ||||
/// <summary> | |||||
/// TF_Output producer = TF_OperationInput(consumer); | |||||
/// There is an edge from producer.oper's output (given by | |||||
/// producer.index) to consumer.oper's input (given by consumer.index). | |||||
/// </summary> | |||||
/// <param name="oper_in"></param> | |||||
/// <returns></returns> | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern TF_Output TF_OperationInput(TF_Input oper_in); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern int TF_OperationInputListLength(IntPtr oper, string arg_name, IntPtr status); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern TF_DataType TF_OperationInputType(TF_Input oper_in); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TF_OperationName(IntPtr oper); | public static extern IntPtr TF_OperationName(IntPtr oper); | ||||
@@ -87,6 +103,17 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern int TF_OperationOutputNumConsumers(TF_Output oper_out); | public static extern int TF_OperationOutputNumConsumers(TF_Output oper_out); | ||||
/// <summary> | |||||
/// Get list of all current consumers of a specific output of an | |||||
/// operation. | |||||
/// </summary> | |||||
/// <param name="oper_out"></param> | |||||
/// <param name="consumers"></param> | |||||
/// <param name="max_consumers"></param> | |||||
/// <returns></returns> | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern int TF_OperationOutputConsumers(TF_Output oper_out, ref IntPtr consumers, int max_consumers); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); | public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); | ||||
@@ -43,7 +43,7 @@ namespace Tensorflow | |||||
{ | { | ||||
if (!graph.is_fetchable(op)) | if (!graph.is_fetchable(op)) | ||||
{ | { | ||||
throw new Exception($"Operation {op.name} has been marked as not fetchable."); | |||||
throw new Exception($"Operation {op.Name} has been marked as not fetchable."); | |||||
} | } | ||||
} | } | ||||
@@ -17,6 +17,7 @@ namespace Tensorflow | |||||
/// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph) | /// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph) | ||||
/// struct => struct (TF_Output output) => (TF_Output output) | /// struct => struct (TF_Output output) => (TF_Output output) | ||||
/// struct* => struct (TF_Output* output) => (TF_Output[] output) | /// struct* => struct (TF_Output* output) => (TF_Output[] output) | ||||
/// struct* => ref IntPtr (TF_Input* consumers) => (ref IntPtr handle), if output is struct[] | |||||
/// const char* => string | /// const char* => string | ||||
/// int32_t => int | /// int32_t => int | ||||
/// int64_t* => long[] | /// int64_t* => long[] | ||||
@@ -21,14 +21,14 @@ namespace TensorFlowNET.UnitTest | |||||
// Make a placeholder operation. | // Make a placeholder operation. | ||||
var feed = c_test_util.Placeholder(graph, s); | var feed = c_test_util.Placeholder(graph, s); | ||||
Assert.AreEqual("feed", feed.name); | |||||
Assert.AreEqual("Placeholder", feed.optype); | |||||
Assert.AreEqual("", feed.device); | |||||
Assert.AreEqual("feed", feed.Name); | |||||
Assert.AreEqual("Placeholder", feed.OpType); | |||||
Assert.AreEqual("", feed.Device); | |||||
Assert.AreEqual(1, feed.NumOutputs); | Assert.AreEqual(1, feed.NumOutputs); | ||||
Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType); | |||||
Assert.AreEqual(1, feed.OutputListLength); | |||||
Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType(0)); | |||||
Assert.AreEqual(1, feed.OutputListLength("output")); | |||||
Assert.AreEqual(0, feed.NumInputs); | Assert.AreEqual(0, feed.NumInputs); | ||||
Assert.AreEqual(0, feed.NumConsumers); | |||||
Assert.AreEqual(0, feed.OutputNumConsumers(0)); | |||||
Assert.AreEqual(0, feed.NumControlInputs); | Assert.AreEqual(0, feed.NumControlInputs); | ||||
Assert.AreEqual(0, feed.NumControlOutputs); | Assert.AreEqual(0, feed.NumControlOutputs); | ||||
@@ -44,9 +44,45 @@ namespace TensorFlowNET.UnitTest | |||||
// Make a constant oper with the scalar "3". | // Make a constant oper with the scalar "3". | ||||
var three = c_test_util.ScalarConst(3, graph, s); | var three = c_test_util.ScalarConst(3, graph, s); | ||||
Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||||
// Add oper. | // Add oper. | ||||
var add = c_test_util.Add(feed, three, graph, s); | var add = c_test_util.Add(feed, three, graph, s); | ||||
Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||||
// Test TF_Operation*() query functions. | |||||
Assert.AreEqual("add", add.Name); | |||||
Assert.AreEqual("AddN", add.OpType); | |||||
Assert.AreEqual("", add.Device); | |||||
Assert.AreEqual(1, add.NumOutputs); | |||||
Assert.AreEqual(TF_DataType.TF_INT32, add.OutputType(0)); | |||||
Assert.AreEqual(1, add.OutputListLength("sum")); | |||||
Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||||
Assert.AreEqual(2, add.InputListLength("inputs")); | |||||
Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||||
Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(0)); | |||||
Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(1)); | |||||
var add_in_0 = add.Input(0); | |||||
Assert.AreEqual(feed, add_in_0.oper); | |||||
Assert.AreEqual(0, add_in_0.index); | |||||
var add_in_1 = add.Input(1); | |||||
Assert.AreEqual(three, add_in_1.oper); | |||||
Assert.AreEqual(0, add_in_1.index); | |||||
Assert.AreEqual(0, add.OutputNumConsumers(0)); | |||||
Assert.AreEqual(0, add.NumControlInputs); | |||||
Assert.AreEqual(0, add.NumControlOutputs); | |||||
Assert.IsTrue(c_test_util.GetAttrValue(add, "T", ref attr_value, s)); | |||||
Assert.AreEqual(DataType.DtInt32, attr_value.Type); | |||||
Assert.IsTrue(c_test_util.GetAttrValue(add, "N", ref attr_value, s)); | |||||
Assert.AreEqual(2, attr_value.I); | |||||
// Placeholder oper now has a consumer. | |||||
Assert.AreEqual(1, feed.OutputNumConsumers(0)); | |||||
TF_Input[] feed_port = feed.OutputConsumers(0, 1); | |||||
Assert.AreEqual(1, feed_port.Length); | |||||
Assert.AreEqual(add, feed_port[0].oper); | |||||
Assert.AreEqual(0, feed_port[0].index); | |||||
} | } | ||||
} | } | ||||
} | } |