@@ -179,31 +179,50 @@ namespace Tensorflow | |||
TF_Output[] return_outputs = new TF_Output[num_return_outputs]; | |||
for (int i = 0; i < num_return_outputs; i++) | |||
{ | |||
return_outputs[i] = Marshal.PtrToStructure<TF_Output>(return_output_handle + (Marshal.SizeOf<TF_Output>() * i)); | |||
var handle = return_output_handle + (Marshal.SizeOf<TF_Output>() * i); | |||
return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle); | |||
} | |||
return return_outputs; | |||
} | |||
public Operation[] ReturnOperations(IntPtr results) | |||
public unsafe Operation[] ReturnOperations(IntPtr results) | |||
{ | |||
IntPtr return_oper_handle = IntPtr.Zero; | |||
TF_Operation return_oper_handle = new TF_Operation(); | |||
int num_return_opers = 0; | |||
c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_opers, ref return_oper_handle); | |||
c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); | |||
Operation[] return_opers = new Operation[num_return_opers]; | |||
for (int i = 0; i < num_return_opers; i++) | |||
{ | |||
// return_opers[i] = Marshal.PtrToStructure<TF_Output>(return_oper_handle + (Marshal.SizeOf<TF_Output>() * i)); | |||
var handle = return_oper_handle.node + Marshal.SizeOf<TF_Operation>() * i; | |||
return_opers[i] = new Operation(*(IntPtr*)handle); | |||
} | |||
return return_opers; | |||
} | |||
public Operation OperationByName(string operName) | |||
{ | |||
return c_api.TF_GraphOperationByName(_handle, operName); | |||
} | |||
public Operation[] get_operations() | |||
{ | |||
return _nodes_by_name.Values.Select(x => x).ToArray(); | |||
} | |||
public GraphDef ToGraphDef() | |||
{ | |||
var s = new Status(); | |||
var buffer = new Buffer(); | |||
c_api.TF_GraphToGraphDef(_handle, buffer, s); | |||
s.Check(); | |||
var def = GraphDef.Parser.ParseFrom(buffer); | |||
buffer.Dispose(); | |||
s.Dispose(); | |||
return def; | |||
} | |||
public void Dispose() | |||
{ | |||
c_api.TF_DeleteGraph(_handle); | |||
@@ -0,0 +1,17 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Runtime.InteropServices; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
[StructLayout(LayoutKind.Sequential)] | |||
public struct TF_ImportGraphDefResults | |||
{ | |||
public IntPtr return_tensors; | |||
public IntPtr return_nodes; | |||
public IntPtr missing_unused_key_names; | |||
public IntPtr missing_unused_key_indexes; | |||
public IntPtr missing_unused_key_names_data; | |||
} | |||
} |
@@ -22,6 +22,13 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_DeleteImportGraphDefOptions(IntPtr opts); | |||
/// <summary> | |||
/// Deletes a results object returned by TF_GraphImportGraphDefWithResults(). | |||
/// </summary> | |||
/// <param name="results"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_DeleteImportGraphDefResults(IntPtr results); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); | |||
@@ -91,9 +98,9 @@ namespace Tensorflow | |||
/// Write out a serialized representation of `graph` (as a GraphDef protocol | |||
/// message) to `output_graph_def` (allocated by TF_NewBuffer()). | |||
/// </summary> | |||
/// <param name="graph"></param> | |||
/// <param name="output_graph_def"></param> | |||
/// <param name="status"></param> | |||
/// <param name="graph">TF_Graph*</param> | |||
/// <param name="output_graph_def">TF_Buffer*</param> | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphToGraphDef(IntPtr graph, IntPtr output_graph_def, IntPtr status); | |||
@@ -110,6 +117,15 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, IntPtr status); | |||
/// <summary> | |||
/// Cause the imported graph to have a control dependency on `oper`. `oper` | |||
/// should exist in the graph being imported into. | |||
/// </summary> | |||
/// <param name="opts"></param> | |||
/// <param name="oper"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_ImportGraphDefOptionsAddControlDependency(IntPtr opts, IntPtr oper); | |||
/// <summary> | |||
/// Set any imported nodes with input `src_name:src_index` to have that input | |||
/// replaced with `dst`. `src_name` refers to a node in the graph to be imported, | |||
@@ -163,6 +179,18 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TF_ImportGraphDefOptionsNumReturnOutputs(IntPtr opts); | |||
/// <summary> | |||
/// Set any imported nodes with control input `src_name` to have that input | |||
/// replaced with `dst`. `src_name` refers to a node in the graph to be imported, | |||
/// `dst` references an operation already existing in the graph being imported | |||
/// into. `src_name` is copied and has no lifetime requirements. | |||
/// </summary> | |||
/// <param name="opts">TF_ImportGraphDefOptions*</param> | |||
/// <param name="src_name">const char*</param> | |||
/// <param name="dst">TF_Operation*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_ImportGraphDefOptionsRemapControlDependency(IntPtr opts, string src_name, IntPtr dst); | |||
/// <summary> | |||
/// Set the prefix to be prepended to the names of nodes in `graph_def` that will | |||
/// be imported into `graph`. `prefix` is copied and has no lifetime | |||
@@ -182,7 +210,7 @@ namespace Tensorflow | |||
/// <param name="num_opers">int*</param> | |||
/// <param name="opers">TF_Operation***</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_ImportGraphDefResultsReturnOperations(IntPtr results, ref int num_opers, ref IntPtr opers); | |||
public static extern void TF_ImportGraphDefResultsReturnOperations(IntPtr results, ref int num_opers, ref TF_Operation opers); | |||
/// <summary> | |||
/// Fetches the return outputs requested via | |||
@@ -18,13 +18,16 @@ namespace Tensorflow | |||
public string Name => c_api.StringPiece(c_api.TF_OperationName(_handle)); | |||
public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | |||
public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||
public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | |||
public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index)); | |||
public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status); | |||
public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index)); | |||
public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index)); | |||
public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status); | |||
public int NumInputs => c_api.TF_OperationNumInputs(_handle); | |||
public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); | |||
public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) | |||
{ | |||
@@ -42,21 +45,41 @@ namespace Tensorflow | |||
public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); | |||
public Operation[] ControlInputs(int max_control_inputs) | |||
public unsafe Operation[] GetControlInputs() | |||
{ | |||
var control_inputs = new Operation[NumControlInputs]; | |||
var control_input_handles = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlInputs); | |||
c_api.TF_OperationGetControlInputs(_handle, control_input_handles, max_control_inputs); | |||
if(NumControlInputs > 0) | |||
{ | |||
IntPtr control_input_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>()); | |||
c_api.TF_OperationGetControlInputs(_handle, control_input_handle, NumControlInputs); | |||
for (int i = 0; i < NumControlInputs; i++) | |||
{ | |||
var handle = control_input_handle + Marshal.SizeOf<IntPtr>() * i; | |||
control_inputs[i] = new Operation(*(IntPtr*)handle); | |||
} | |||
} | |||
return control_inputs; | |||
} | |||
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | |||
public Operation[] ControlOutputs(int max_control_outputs) | |||
public unsafe Operation[] GetControlOutputs() | |||
{ | |||
var control_outputs = new Operation[NumControlOutputs]; | |||
var control_output_handles = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs); | |||
c_api.TF_OperationGetControlInputs(_handle, control_output_handles, max_control_outputs); | |||
if(NumControlOutputs > 0) | |||
{ | |||
IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>()); | |||
c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs); | |||
for (int i = 0; i < NumControlInputs; i++) | |||
{ | |||
var handle = control_output_handle + Marshal.SizeOf<IntPtr>() * i; | |||
control_outputs[i] = new Operation(*(IntPtr*)handle); | |||
} | |||
} | |||
return control_outputs; | |||
} | |||
@@ -0,0 +1,13 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Runtime.InteropServices; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
[StructLayout(LayoutKind.Sequential)] | |||
public struct TF_Operation | |||
{ | |||
public IntPtr node; | |||
} | |||
} |
@@ -16,7 +16,7 @@ namespace Tensorflow | |||
/// TF_XX** => ref IntPtr (TF_Operation** op) => (ref IntPtr op) | |||
/// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph) | |||
/// struct => struct (TF_Output output) => (TF_Output output) | |||
/// struct* => struct (TF_Output* output) => (TF_Output[] output) | |||
/// struct* => struct[] (TF_Output* output) => (TF_Output[] output) | |||
/// struct* => struct* for ref | |||
/// const char* => string | |||
/// int32_t => int | |||
@@ -228,9 +228,9 @@ namespace TensorFlowNET.UnitTest | |||
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
Operation scalar = c_api.TF_GraphOperationByName(graph, "imported/scalar"); | |||
Operation feed = c_api.TF_GraphOperationByName(graph, "imported/feed"); | |||
Operation neg = c_api.TF_GraphOperationByName(graph, "imported/neg"); | |||
Operation scalar = graph.OperationByName("imported/scalar"); | |||
Operation feed = graph.OperationByName("imported/feed"); | |||
Operation neg = graph.OperationByName("imported/neg"); | |||
// Test basic structure of the imported graph. | |||
EXPECT_EQ(0, scalar.NumInputs); | |||
@@ -243,19 +243,19 @@ namespace TensorFlowNET.UnitTest | |||
// Test that we can't see control edges involving the source and sink nodes. | |||
EXPECT_EQ(0, scalar.NumControlInputs); | |||
EXPECT_EQ(0, scalar.ControlInputs(100).Length); | |||
EXPECT_EQ(0, scalar.GetControlInputs().Length); | |||
EXPECT_EQ(0, scalar.NumControlOutputs); | |||
EXPECT_EQ(0, scalar.ControlOutputs(100).Length); | |||
EXPECT_EQ(0, scalar.GetControlOutputs().Length); | |||
EXPECT_EQ(0, feed.NumControlInputs); | |||
EXPECT_EQ(0, feed.ControlInputs(100).Length); | |||
EXPECT_EQ(0, feed.GetControlInputs().Length); | |||
EXPECT_EQ(0, feed.NumControlOutputs); | |||
EXPECT_EQ(0, feed.ControlOutputs(100).Length); | |||
EXPECT_EQ(0, feed.GetControlOutputs().Length); | |||
EXPECT_EQ(0, neg.NumControlInputs); | |||
EXPECT_EQ(0, neg.ControlInputs(100).Length); | |||
EXPECT_EQ(0, neg.GetControlInputs().Length); | |||
EXPECT_EQ(0, neg.NumControlOutputs); | |||
EXPECT_EQ(0, neg.ControlOutputs(100).Length); | |||
EXPECT_EQ(0, neg.GetControlOutputs().Length); | |||
// Import it again, with an input mapping, return outputs, and a return | |||
// operation, into the same graph. | |||
@@ -271,9 +271,9 @@ namespace TensorFlowNET.UnitTest | |||
var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
Operation scalar2 = c_api.TF_GraphOperationByName(graph, "imported2/scalar"); | |||
Operation feed2 = c_api.TF_GraphOperationByName(graph, "imported2/feed"); | |||
Operation neg2 = c_api.TF_GraphOperationByName(graph, "imported2/neg"); | |||
Operation scalar2 = graph.OperationByName("imported2/scalar"); | |||
Operation feed2 = graph.OperationByName("imported2/feed"); | |||
Operation neg2 = graph.OperationByName("imported2/neg"); | |||
// Check input mapping | |||
neg_input = neg.Input(0); | |||
@@ -289,8 +289,72 @@ namespace TensorFlowNET.UnitTest | |||
EXPECT_EQ(0, return_outputs[1].index); | |||
// Check return operation | |||
var num_return_opers = graph.ReturnOperations(results); | |||
ASSERT_EQ(1, num_return_opers); | |||
var return_opers = graph.ReturnOperations(results); | |||
ASSERT_EQ(1, return_opers.Length); | |||
EXPECT_EQ(scalar2, return_opers[0]); // not remapped | |||
c_api.TF_DeleteImportGraphDefResults(results); | |||
// Import again, with control dependencies, into the same graph. | |||
c_api.TF_DeleteImportGraphDefOptions(opts); | |||
opts = c_api.TF_NewImportGraphDefOptions(); | |||
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3"); | |||
c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed); | |||
c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); | |||
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
var scalar3 = graph.OperationByName("imported3/scalar"); | |||
var feed3 = graph.OperationByName("imported3/feed"); | |||
var neg3 = graph.OperationByName("imported3/neg"); | |||
ASSERT_TRUE(scalar3 != IntPtr.Zero); | |||
ASSERT_TRUE(feed3 != IntPtr.Zero); | |||
ASSERT_TRUE(neg3 != IntPtr.Zero); | |||
// Check that newly-imported scalar and feed have control deps (neg3 will | |||
// inherit them from input) | |||
var control_inputs = scalar3.GetControlInputs(); | |||
ASSERT_EQ(2, scalar3.NumControlInputs); | |||
EXPECT_EQ(feed, control_inputs[0]); | |||
EXPECT_EQ(feed2, control_inputs[1]); | |||
control_inputs = feed3.GetControlInputs(); | |||
ASSERT_EQ(2, feed3.NumControlInputs); | |||
EXPECT_EQ(feed, control_inputs[0]); | |||
EXPECT_EQ(feed2, control_inputs[1]); | |||
// Export to a graph def so we can import a graph with control dependencies | |||
graph_def.Dispose(); | |||
graph_def = new Buffer(); | |||
c_api.TF_GraphToGraphDef(graph, graph_def, s); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
// Import again, with remapped control dependency, into the same graph | |||
c_api.TF_DeleteImportGraphDefOptions(opts); | |||
opts = c_api.TF_NewImportGraphDefOptions(); | |||
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4"); | |||
c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed); | |||
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||
ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
var scalar4 = graph.OperationByName("imported4/imported3/scalar"); | |||
var feed4 = graph.OperationByName("imported4/imported2/feed"); | |||
// Check that imported `imported3/scalar` has remapped control dep from | |||
// original graph and imported control dep | |||
control_inputs = scalar4.GetControlInputs(); | |||
ASSERT_EQ(2, scalar4.NumControlInputs); | |||
EXPECT_EQ(feed, control_inputs[0]); | |||
EXPECT_EQ(feed4, control_inputs[1]); | |||
c_api.TF_DeleteImportGraphDefOptions(opts); | |||
c_api.TF_DeleteBuffer(graph_def); | |||
// Can add nodes to the imported graph without trouble. | |||
c_test_util.Add(feed, scalar, graph, s); | |||
ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
//graph.Dispose(); | |||
s.Dispose(); | |||
} | |||
} | |||
} |