Browse Source

Completed TEST(CAPI, ImportGraphDef) except graph.Dispose exception threw.

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
99c748fa5a
7 changed files with 194 additions and 30 deletions
  1. +24
    -5
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +17
    -0
      src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs
  3. +32
    -4
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  4. +29
    -6
      src/TensorFlowNET.Core/Operations/Operation.cs
  5. +13
    -0
      src/TensorFlowNET.Core/Operations/TF_Operation.cs
  6. +1
    -1
      src/TensorFlowNET.Core/c_api.cs
  7. +78
    -14
      test/TensorFlowNET.UnitTest/GraphTest.cs

+ 24
- 5
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -179,31 +179,50 @@ namespace Tensorflow
TF_Output[] return_outputs = new TF_Output[num_return_outputs];
for (int i = 0; i < num_return_outputs; i++)
{
return_outputs[i] = Marshal.PtrToStructure<TF_Output>(return_output_handle + (Marshal.SizeOf<TF_Output>() * i));
var handle = return_output_handle + (Marshal.SizeOf<TF_Output>() * i);
return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle);
}

return return_outputs;
}

public Operation[] ReturnOperations(IntPtr results)
public unsafe Operation[] ReturnOperations(IntPtr results)
{
IntPtr return_oper_handle = IntPtr.Zero;
TF_Operation return_oper_handle = new TF_Operation();
int num_return_opers = 0;
c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_opers, ref return_oper_handle);
c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle);
Operation[] return_opers = new Operation[num_return_opers];
for (int i = 0; i < num_return_opers; i++)
{
// return_opers[i] = Marshal.PtrToStructure<TF_Output>(return_oper_handle + (Marshal.SizeOf<TF_Output>() * i));
var handle = return_oper_handle.node + Marshal.SizeOf<TF_Operation>() * i;
return_opers[i] = new Operation(*(IntPtr*)handle);
}

return return_opers;
}

public Operation OperationByName(string operName)
{
return c_api.TF_GraphOperationByName(_handle, operName);
}

public Operation[] get_operations()
{
return _nodes_by_name.Values.Select(x => x).ToArray();
}

public GraphDef ToGraphDef()
{
var s = new Status();
var buffer = new Buffer();
c_api.TF_GraphToGraphDef(_handle, buffer, s);
s.Check();
var def = GraphDef.Parser.ParseFrom(buffer);
buffer.Dispose();
s.Dispose();
return def;
}

public void Dispose()
{
c_api.TF_DeleteGraph(_handle);


+ 17
- 0
src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs View File

@@ -0,0 +1,17 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace Tensorflow
{
[StructLayout(LayoutKind.Sequential)]
public struct TF_ImportGraphDefResults
{
public IntPtr return_tensors;
public IntPtr return_nodes;
public IntPtr missing_unused_key_names;
public IntPtr missing_unused_key_indexes;
public IntPtr missing_unused_key_names_data;
}
}

+ 32
- 4
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -22,6 +22,13 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TF_DeleteImportGraphDefOptions(IntPtr opts);

/// <summary>
/// Deletes a results object returned by TF_GraphImportGraphDefWithResults().
/// </summary>
/// <param name="results"></param>
[DllImport(TensorFlowLibName)]
public static extern void TF_DeleteImportGraphDefResults(IntPtr results);

[DllImport(TensorFlowLibName)]
public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status);

@@ -91,9 +98,9 @@ namespace Tensorflow
/// Write out a serialized representation of `graph` (as a GraphDef protocol
/// message) to `output_graph_def` (allocated by TF_NewBuffer()).
/// </summary>
/// <param name="graph"></param>
/// <param name="output_graph_def"></param>
/// <param name="status"></param>
/// <param name="graph">TF_Graph*</param>
/// <param name="output_graph_def">TF_Buffer*</param>
/// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_GraphToGraphDef(IntPtr graph, IntPtr output_graph_def, IntPtr status);
@@ -110,6 +117,15 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, IntPtr status);

/// <summary>
/// Cause the imported graph to have a control dependency on `oper`. `oper`
/// should exist in the graph being imported into.
/// </summary>
/// <param name="opts"></param>
/// <param name="oper"></param>
[DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsAddControlDependency(IntPtr opts, IntPtr oper);

/// <summary>
/// Set any imported nodes with input `src_name:src_index` to have that input
/// replaced with `dst`. `src_name` refers to a node in the graph to be imported,
@@ -163,6 +179,18 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern int TF_ImportGraphDefOptionsNumReturnOutputs(IntPtr opts);

/// <summary>
/// Set any imported nodes with control input `src_name` to have that input
/// replaced with `dst`. `src_name` refers to a node in the graph to be imported,
/// `dst` references an operation already existing in the graph being imported
/// into. `src_name` is copied and has no lifetime requirements.
/// </summary>
/// <param name="opts">TF_ImportGraphDefOptions*</param>
/// <param name="src_name">const char*</param>
/// <param name="dst">TF_Operation*</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsRemapControlDependency(IntPtr opts, string src_name, IntPtr dst);

/// <summary>
/// Set the prefix to be prepended to the names of nodes in `graph_def` that will
/// be imported into `graph`. `prefix` is copied and has no lifetime
@@ -182,7 +210,7 @@ namespace Tensorflow
/// <param name="num_opers">int*</param>
/// <param name="opers">TF_Operation***</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefResultsReturnOperations(IntPtr results, ref int num_opers, ref IntPtr opers);
public static extern void TF_ImportGraphDefResultsReturnOperations(IntPtr results, ref int num_opers, ref TF_Operation opers);

/// <summary>
/// Fetches the return outputs requested via


+ 29
- 6
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -18,13 +18,16 @@ namespace Tensorflow
public string Name => c_api.StringPiece(c_api.TF_OperationName(_handle));
public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle));
public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));

public int NumOutputs => c_api.TF_OperationNumOutputs(_handle);
public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index));
public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status);

public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index));
public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index));
public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status);
public int NumInputs => c_api.TF_OperationNumInputs(_handle);

public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index));
public unsafe TF_Input[] OutputConsumers(int index, int max_consumers)
{
@@ -42,21 +45,41 @@ namespace Tensorflow

public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle);

public Operation[] ControlInputs(int max_control_inputs)
public unsafe Operation[] GetControlInputs()
{
var control_inputs = new Operation[NumControlInputs];
var control_input_handles = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlInputs);
c_api.TF_OperationGetControlInputs(_handle, control_input_handles, max_control_inputs);

if(NumControlInputs > 0)
{
IntPtr control_input_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>());
c_api.TF_OperationGetControlInputs(_handle, control_input_handle, NumControlInputs);
for (int i = 0; i < NumControlInputs; i++)
{
var handle = control_input_handle + Marshal.SizeOf<IntPtr>() * i;
control_inputs[i] = new Operation(*(IntPtr*)handle);
}
}

return control_inputs;
}

public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);

public Operation[] ControlOutputs(int max_control_outputs)
public unsafe Operation[] GetControlOutputs()
{
var control_outputs = new Operation[NumControlOutputs];
var control_output_handles = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs);
c_api.TF_OperationGetControlInputs(_handle, control_output_handles, max_control_outputs);

if(NumControlOutputs > 0)
{
IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>());
c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs);
for (int i = 0; i < NumControlInputs; i++)
{
var handle = control_output_handle + Marshal.SizeOf<IntPtr>() * i;
control_outputs[i] = new Operation(*(IntPtr*)handle);
}
}

return control_outputs;
}



+ 13
- 0
src/TensorFlowNET.Core/Operations/TF_Operation.cs View File

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace Tensorflow
{
[StructLayout(LayoutKind.Sequential)]
public struct TF_Operation
{
public IntPtr node;
}
}

+ 1
- 1
src/TensorFlowNET.Core/c_api.cs View File

@@ -16,7 +16,7 @@ namespace Tensorflow
/// TF_XX** => ref IntPtr (TF_Operation** op) => (ref IntPtr op)
/// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph)
/// struct => struct (TF_Output output) => (TF_Output output)
/// struct* => struct (TF_Output* output) => (TF_Output[] output)
/// struct* => struct[] (TF_Output* output) => (TF_Output[] output)
/// struct* => struct* for ref
/// const char* => string
/// int32_t => int


+ 78
- 14
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -228,9 +228,9 @@ namespace TensorFlowNET.UnitTest
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s);
EXPECT_EQ(TF_Code.TF_OK, s.Code);

Operation scalar = c_api.TF_GraphOperationByName(graph, "imported/scalar");
Operation feed = c_api.TF_GraphOperationByName(graph, "imported/feed");
Operation neg = c_api.TF_GraphOperationByName(graph, "imported/neg");
Operation scalar = graph.OperationByName("imported/scalar");
Operation feed = graph.OperationByName("imported/feed");
Operation neg = graph.OperationByName("imported/neg");

// Test basic structure of the imported graph.
EXPECT_EQ(0, scalar.NumInputs);
@@ -243,19 +243,19 @@ namespace TensorFlowNET.UnitTest

// Test that we can't see control edges involving the source and sink nodes.
EXPECT_EQ(0, scalar.NumControlInputs);
EXPECT_EQ(0, scalar.ControlInputs(100).Length);
EXPECT_EQ(0, scalar.GetControlInputs().Length);
EXPECT_EQ(0, scalar.NumControlOutputs);
EXPECT_EQ(0, scalar.ControlOutputs(100).Length);
EXPECT_EQ(0, scalar.GetControlOutputs().Length);

EXPECT_EQ(0, feed.NumControlInputs);
EXPECT_EQ(0, feed.ControlInputs(100).Length);
EXPECT_EQ(0, feed.GetControlInputs().Length);
EXPECT_EQ(0, feed.NumControlOutputs);
EXPECT_EQ(0, feed.ControlOutputs(100).Length);
EXPECT_EQ(0, feed.GetControlOutputs().Length);

EXPECT_EQ(0, neg.NumControlInputs);
EXPECT_EQ(0, neg.ControlInputs(100).Length);
EXPECT_EQ(0, neg.GetControlInputs().Length);
EXPECT_EQ(0, neg.NumControlOutputs);
EXPECT_EQ(0, neg.ControlOutputs(100).Length);
EXPECT_EQ(0, neg.GetControlOutputs().Length);

// Import it again, with an input mapping, return outputs, and a return
// operation, into the same graph.
@@ -271,9 +271,9 @@ namespace TensorFlowNET.UnitTest
var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
EXPECT_EQ(TF_Code.TF_OK, s.Code);

Operation scalar2 = c_api.TF_GraphOperationByName(graph, "imported2/scalar");
Operation feed2 = c_api.TF_GraphOperationByName(graph, "imported2/feed");
Operation neg2 = c_api.TF_GraphOperationByName(graph, "imported2/neg");
Operation scalar2 = graph.OperationByName("imported2/scalar");
Operation feed2 = graph.OperationByName("imported2/feed");
Operation neg2 = graph.OperationByName("imported2/neg");

// Check input mapping
neg_input = neg.Input(0);
@@ -289,8 +289,72 @@ namespace TensorFlowNET.UnitTest
EXPECT_EQ(0, return_outputs[1].index);

// Check return operation
var num_return_opers = graph.ReturnOperations(results);
ASSERT_EQ(1, num_return_opers);
var return_opers = graph.ReturnOperations(results);
ASSERT_EQ(1, return_opers.Length);
EXPECT_EQ(scalar2, return_opers[0]); // not remapped
c_api.TF_DeleteImportGraphDefResults(results);

// Import again, with control dependencies, into the same graph.
c_api.TF_DeleteImportGraphDefOptions(opts);
opts = c_api.TF_NewImportGraphDefOptions();
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3");
c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed);
c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2);
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s);
EXPECT_EQ(TF_Code.TF_OK, s.Code);

var scalar3 = graph.OperationByName("imported3/scalar");
var feed3 = graph.OperationByName("imported3/feed");
var neg3 = graph.OperationByName("imported3/neg");
ASSERT_TRUE(scalar3 != IntPtr.Zero);
ASSERT_TRUE(feed3 != IntPtr.Zero);
ASSERT_TRUE(neg3 != IntPtr.Zero);

// Check that newly-imported scalar and feed have control deps (neg3 will
// inherit them from input)
var control_inputs = scalar3.GetControlInputs();
ASSERT_EQ(2, scalar3.NumControlInputs);
EXPECT_EQ(feed, control_inputs[0]);
EXPECT_EQ(feed2, control_inputs[1]);

control_inputs = feed3.GetControlInputs();
ASSERT_EQ(2, feed3.NumControlInputs);
EXPECT_EQ(feed, control_inputs[0]);
EXPECT_EQ(feed2, control_inputs[1]);

// Export to a graph def so we can import a graph with control dependencies
graph_def.Dispose();
graph_def = new Buffer();
c_api.TF_GraphToGraphDef(graph, graph_def, s);
EXPECT_EQ(TF_Code.TF_OK, s.Code);

// Import again, with remapped control dependency, into the same graph
c_api.TF_DeleteImportGraphDefOptions(opts);
opts = c_api.TF_NewImportGraphDefOptions();
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4");
c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed);
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);

var scalar4 = graph.OperationByName("imported4/imported3/scalar");
var feed4 = graph.OperationByName("imported4/imported2/feed");

// Check that imported `imported3/scalar` has remapped control dep from
// original graph and imported control dep
control_inputs = scalar4.GetControlInputs();
ASSERT_EQ(2, scalar4.NumControlInputs);
EXPECT_EQ(feed, control_inputs[0]);
EXPECT_EQ(feed4, control_inputs[1]);

c_api.TF_DeleteImportGraphDefOptions(opts);
c_api.TF_DeleteBuffer(graph_def);

// Can add nodes to the imported graph without trouble.
c_test_util.Add(feed, scalar, graph, s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);

//graph.Dispose();
s.Dispose();
}
}
}

Loading…
Cancel
Save