Browse Source

OutputConsumers

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
7910b79258
7 changed files with 112 additions and 14 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +34
    -6
      src/TensorFlowNET.Core/Operations/Operation.cs
  3. +6
    -0
      src/TensorFlowNET.Core/Operations/TF_Input.cs
  4. +27
    -0
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  6. +1
    -0
      src/TensorFlowNET.Core/c_api.cs
  7. +42
    -6
      test/TensorFlowNET.UnitTest/GraphTest.cs

+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -126,7 +126,7 @@ namespace Tensorflow
}
else if (tensor_or_op is Operation)
{
return !_unfetchable_ops.Contains((tensor_or_op as Operation).name);
return !_unfetchable_ops.Contains((tensor_or_op as Operation).Name);
}

return false;


+ 34
- 6
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -15,14 +15,31 @@ namespace Tensorflow

private Status status = new Status();

public string name => c_api.StringPiece(c_api.TF_OperationName(_handle));
public string optype => c_api.StringPiece(c_api.TF_OperationOpType(_handle));
public string device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));
public string Name => c_api.StringPiece(c_api.TF_OperationName(_handle));
public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle));
public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));
public int NumOutputs => c_api.TF_OperationNumOutputs(_handle);
public TF_DataType OutputType => c_api.TF_OperationOutputType(new TF_Output(_handle, 0));
public int OutputListLength => c_api.TF_OperationOutputListLength(_handle, "output", status);
public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index));
public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status);
public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index));
public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index));
public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status);
public int NumInputs => c_api.TF_OperationNumInputs(_handle);
public int NumConsumers => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0));
public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index));
public TF_Input[] OutputConsumers(int index, int max_consumers)
{
IntPtr handle = IntPtr.Zero;
int size = Marshal.SizeOf<TF_Input>();
int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), ref handle, max_consumers);
var consumers = new TF_Input[num];
for(int i = 0; i < num; i++)
{
consumers[0] = Marshal.PtrToStructure<TF_Input>(handle + i * size);
}

return consumers;
}

public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle);
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);

@@ -92,5 +109,16 @@ namespace Tensorflow
{
return op._handle;
}

public override bool Equals(object obj)
{
switch (obj)
{
case IntPtr val:
return val == _handle;
}

return base.Equals(obj);
}
}
}

+ 6
- 0
src/TensorFlowNET.Core/Operations/TF_Input.cs View File

@@ -8,6 +8,12 @@ namespace Tensorflow
[StructLayout(LayoutKind.Sequential)]
public struct TF_Input
{
public TF_Input(IntPtr oper, int index)
{
this.oper = oper;
this.index = index;
}

public IntPtr oper;
public int index;
}


+ 27
- 0
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -49,6 +49,22 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, IntPtr status);

/// <summary>
/// TF_Output producer = TF_OperationInput(consumer);
/// There is an edge from producer.oper's output (given by
/// producer.index) to consumer.oper's input (given by consumer.index).
/// </summary>
/// <param name="oper_in"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern TF_Output TF_OperationInput(TF_Input oper_in);

[DllImport(TensorFlowLibName)]
public static extern int TF_OperationInputListLength(IntPtr oper, string arg_name, IntPtr status);

[DllImport(TensorFlowLibName)]
public static extern TF_DataType TF_OperationInputType(TF_Input oper_in);

[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_OperationName(IntPtr oper);

@@ -87,6 +103,17 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern int TF_OperationOutputNumConsumers(TF_Output oper_out);

/// <summary>
/// Get list of all current consumers of a specific output of an
/// operation.
/// </summary>
/// <param name="oper_out"></param>
/// <param name="consumers"></param>
/// <param name="max_consumers"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern int TF_OperationOutputConsumers(TF_Output oper_out, ref IntPtr consumers, int max_consumers);

[DllImport(TensorFlowLibName)]
public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out);



+ 1
- 1
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -43,7 +43,7 @@ namespace Tensorflow
{
if (!graph.is_fetchable(op))
{
throw new Exception($"Operation {op.name} has been marked as not fetchable.");
throw new Exception($"Operation {op.Name} has been marked as not fetchable.");
}
}



+ 1
- 0
src/TensorFlowNET.Core/c_api.cs View File

@@ -17,6 +17,7 @@ namespace Tensorflow
/// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph)
/// struct => struct (TF_Output output) => (TF_Output output)
/// struct* => struct (TF_Output* output) => (TF_Output[] output)
/// struct* => ref IntPtr (TF_Input* consumers) => (ref IntPtr handle), if output is struct[]
/// const char* => string
/// int32_t => int
/// int64_t* => long[]


+ 42
- 6
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -21,14 +21,14 @@ namespace TensorFlowNET.UnitTest

// Make a placeholder operation.
var feed = c_test_util.Placeholder(graph, s);
Assert.AreEqual("feed", feed.name);
Assert.AreEqual("Placeholder", feed.optype);
Assert.AreEqual("", feed.device);
Assert.AreEqual("feed", feed.Name);
Assert.AreEqual("Placeholder", feed.OpType);
Assert.AreEqual("", feed.Device);
Assert.AreEqual(1, feed.NumOutputs);
Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType);
Assert.AreEqual(1, feed.OutputListLength);
Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType(0));
Assert.AreEqual(1, feed.OutputListLength("output"));
Assert.AreEqual(0, feed.NumInputs);
Assert.AreEqual(0, feed.NumConsumers);
Assert.AreEqual(0, feed.OutputNumConsumers(0));
Assert.AreEqual(0, feed.NumControlInputs);
Assert.AreEqual(0, feed.NumControlOutputs);

@@ -44,9 +44,45 @@ namespace TensorFlowNET.UnitTest

// Make a constant oper with the scalar "3".
var three = c_test_util.ScalarConst(3, graph, s);
Assert.AreEqual(TF_Code.TF_OK, s.Code);

// Add oper.
var add = c_test_util.Add(feed, three, graph, s);
Assert.AreEqual(TF_Code.TF_OK, s.Code);

// Test TF_Operation*() query functions.
Assert.AreEqual("add", add.Name);
Assert.AreEqual("AddN", add.OpType);
Assert.AreEqual("", add.Device);
Assert.AreEqual(1, add.NumOutputs);
Assert.AreEqual(TF_DataType.TF_INT32, add.OutputType(0));
Assert.AreEqual(1, add.OutputListLength("sum"));
Assert.AreEqual(TF_Code.TF_OK, s.Code);
Assert.AreEqual(2, add.InputListLength("inputs"));
Assert.AreEqual(TF_Code.TF_OK, s.Code);
Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(0));
Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(1));
var add_in_0 = add.Input(0);
Assert.AreEqual(feed, add_in_0.oper);
Assert.AreEqual(0, add_in_0.index);
var add_in_1 = add.Input(1);
Assert.AreEqual(three, add_in_1.oper);
Assert.AreEqual(0, add_in_1.index);
Assert.AreEqual(0, add.OutputNumConsumers(0));
Assert.AreEqual(0, add.NumControlInputs);
Assert.AreEqual(0, add.NumControlOutputs);

Assert.IsTrue(c_test_util.GetAttrValue(add, "T", ref attr_value, s));
Assert.AreEqual(DataType.DtInt32, attr_value.Type);
Assert.IsTrue(c_test_util.GetAttrValue(add, "N", ref attr_value, s));
Assert.AreEqual(2, attr_value.I);

// Placeholder oper now has a consumer.
Assert.AreEqual(1, feed.OutputNumConsumers(0));
TF_Input[] feed_port = feed.OutputConsumers(0, 1);
Assert.AreEqual(1, feed_port.Length);
Assert.AreEqual(add, feed_port[0].oper);
Assert.AreEqual(0, feed_port[0].index);
}
}
}

Loading…
Cancel
Save