@@ -126,7 +126,7 @@ namespace Tensorflow | |||
} | |||
else if (tensor_or_op is Operation) | |||
{ | |||
return !_unfetchable_ops.Contains((tensor_or_op as Operation).name); | |||
return !_unfetchable_ops.Contains((tensor_or_op as Operation).Name); | |||
} | |||
return false; | |||
@@ -15,14 +15,31 @@ namespace Tensorflow | |||
private Status status = new Status(); | |||
public string name => c_api.StringPiece(c_api.TF_OperationName(_handle)); | |||
public string optype => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | |||
public string device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||
public string Name => c_api.StringPiece(c_api.TF_OperationName(_handle)); | |||
public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | |||
public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||
public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | |||
public TF_DataType OutputType => c_api.TF_OperationOutputType(new TF_Output(_handle, 0)); | |||
public int OutputListLength => c_api.TF_OperationOutputListLength(_handle, "output", status); | |||
public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index)); | |||
public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status); | |||
public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index)); | |||
public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index)); | |||
public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status); | |||
public int NumInputs => c_api.TF_OperationNumInputs(_handle); | |||
public int NumConsumers => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0)); | |||
public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); | |||
public TF_Input[] OutputConsumers(int index, int max_consumers) | |||
{ | |||
IntPtr handle = IntPtr.Zero; | |||
int size = Marshal.SizeOf<TF_Input>(); | |||
int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), ref handle, max_consumers); | |||
var consumers = new TF_Input[num]; | |||
for(int i = 0; i < num; i++) | |||
{ | |||
consumers[0] = Marshal.PtrToStructure<TF_Input>(handle + i * size); | |||
} | |||
return consumers; | |||
} | |||
public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); | |||
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | |||
@@ -92,5 +109,16 @@ namespace Tensorflow | |||
{ | |||
return op._handle; | |||
} | |||
public override bool Equals(object obj) | |||
{ | |||
switch (obj) | |||
{ | |||
case IntPtr val: | |||
return val == _handle; | |||
} | |||
return base.Equals(obj); | |||
} | |||
} | |||
} |
@@ -8,6 +8,12 @@ namespace Tensorflow | |||
[StructLayout(LayoutKind.Sequential)] | |||
public struct TF_Input | |||
{ | |||
public TF_Input(IntPtr oper, int index) | |||
{ | |||
this.oper = oper; | |||
this.index = index; | |||
} | |||
public IntPtr oper; | |||
public int index; | |||
} | |||
@@ -49,6 +49,22 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, IntPtr status); | |||
/// <summary> | |||
/// TF_Output producer = TF_OperationInput(consumer); | |||
/// There is an edge from producer.oper's output (given by | |||
/// producer.index) to consumer.oper's input (given by consumer.index). | |||
/// </summary> | |||
/// <param name="oper_in"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern TF_Output TF_OperationInput(TF_Input oper_in); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TF_OperationInputListLength(IntPtr oper, string arg_name, IntPtr status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern TF_DataType TF_OperationInputType(TF_Input oper_in); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_OperationName(IntPtr oper); | |||
@@ -87,6 +103,17 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TF_OperationOutputNumConsumers(TF_Output oper_out); | |||
/// <summary> | |||
/// Get list of all current consumers of a specific output of an | |||
/// operation. | |||
/// </summary> | |||
/// <param name="oper_out"></param> | |||
/// <param name="consumers"></param> | |||
/// <param name="max_consumers"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TF_OperationOutputConsumers(TF_Output oper_out, ref IntPtr consumers, int max_consumers); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); | |||
@@ -43,7 +43,7 @@ namespace Tensorflow | |||
{ | |||
if (!graph.is_fetchable(op)) | |||
{ | |||
throw new Exception($"Operation {op.name} has been marked as not fetchable."); | |||
throw new Exception($"Operation {op.Name} has been marked as not fetchable."); | |||
} | |||
} | |||
@@ -17,6 +17,7 @@ namespace Tensorflow | |||
/// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph) | |||
/// struct => struct (TF_Output output) => (TF_Output output) | |||
/// struct* => struct (TF_Output* output) => (TF_Output[] output) | |||
/// struct* => ref IntPtr (TF_Input* consumers) => (ref IntPtr handle), if output is struct[] | |||
/// const char* => string | |||
/// int32_t => int | |||
/// int64_t* => long[] | |||
@@ -21,14 +21,14 @@ namespace TensorFlowNET.UnitTest | |||
// Make a placeholder operation. | |||
var feed = c_test_util.Placeholder(graph, s); | |||
Assert.AreEqual("feed", feed.name); | |||
Assert.AreEqual("Placeholder", feed.optype); | |||
Assert.AreEqual("", feed.device); | |||
Assert.AreEqual("feed", feed.Name); | |||
Assert.AreEqual("Placeholder", feed.OpType); | |||
Assert.AreEqual("", feed.Device); | |||
Assert.AreEqual(1, feed.NumOutputs); | |||
Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType); | |||
Assert.AreEqual(1, feed.OutputListLength); | |||
Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType(0)); | |||
Assert.AreEqual(1, feed.OutputListLength("output")); | |||
Assert.AreEqual(0, feed.NumInputs); | |||
Assert.AreEqual(0, feed.NumConsumers); | |||
Assert.AreEqual(0, feed.OutputNumConsumers(0)); | |||
Assert.AreEqual(0, feed.NumControlInputs); | |||
Assert.AreEqual(0, feed.NumControlOutputs); | |||
@@ -44,9 +44,45 @@ namespace TensorFlowNET.UnitTest | |||
// Make a constant oper with the scalar "3". | |||
var three = c_test_util.ScalarConst(3, graph, s); | |||
Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||
// Add oper. | |||
var add = c_test_util.Add(feed, three, graph, s); | |||
Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||
// Test TF_Operation*() query functions. | |||
Assert.AreEqual("add", add.Name); | |||
Assert.AreEqual("AddN", add.OpType); | |||
Assert.AreEqual("", add.Device); | |||
Assert.AreEqual(1, add.NumOutputs); | |||
Assert.AreEqual(TF_DataType.TF_INT32, add.OutputType(0)); | |||
Assert.AreEqual(1, add.OutputListLength("sum")); | |||
Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||
Assert.AreEqual(2, add.InputListLength("inputs")); | |||
Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||
Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(0)); | |||
Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(1)); | |||
var add_in_0 = add.Input(0); | |||
Assert.AreEqual(feed, add_in_0.oper); | |||
Assert.AreEqual(0, add_in_0.index); | |||
var add_in_1 = add.Input(1); | |||
Assert.AreEqual(three, add_in_1.oper); | |||
Assert.AreEqual(0, add_in_1.index); | |||
Assert.AreEqual(0, add.OutputNumConsumers(0)); | |||
Assert.AreEqual(0, add.NumControlInputs); | |||
Assert.AreEqual(0, add.NumControlOutputs); | |||
Assert.IsTrue(c_test_util.GetAttrValue(add, "T", ref attr_value, s)); | |||
Assert.AreEqual(DataType.DtInt32, attr_value.Type); | |||
Assert.IsTrue(c_test_util.GetAttrValue(add, "N", ref attr_value, s)); | |||
Assert.AreEqual(2, attr_value.I); | |||
// Placeholder oper now has a consumer. | |||
Assert.AreEqual(1, feed.OutputNumConsumers(0)); | |||
TF_Input[] feed_port = feed.OutputConsumers(0, 1); | |||
Assert.AreEqual(1, feed_port.Length); | |||
Assert.AreEqual(add, feed_port[0].oper); | |||
Assert.AreEqual(0, feed_port[0].index); | |||
} | |||
} | |||
} |