@@ -0,0 +1,21 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Runtime.InteropServices; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public partial class Graph | |||
{ | |||
public Buffer ToGraphDef(Status s) | |||
{ | |||
var buffer = new Buffer(); | |||
c_api.TF_GraphToGraphDef(_handle, buffer, s); | |||
s.Check(); | |||
// var def = GraphDef.Parser.ParseFrom(buffer); | |||
// buffer.Dispose(); | |||
return buffer; | |||
} | |||
} | |||
} |
@@ -0,0 +1,28 @@ | |||
using Google.Protobuf; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Runtime.InteropServices; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public partial class Graph | |||
{ | |||
public unsafe TF_Output[] ImportGraphDefWithReturnOutputs(Buffer graph_def, ImportGraphDefOptions opts, Status s) | |||
{ | |||
var num_return_outputs = opts.NumReturnOutputs; | |||
var return_outputs = new TF_Output[num_return_outputs]; | |||
TF_Output* return_output_handle = (TF_Output*)Marshal.AllocHGlobal(Marshal.SizeOf<TF_Output>() * 2); | |||
c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); | |||
for (int i = 0; i < num_return_outputs; i++) | |||
{ | |||
var handle = return_output_handle + i * Marshal.SizeOf<TF_Output>(); | |||
return_outputs[i] = new TF_Output((*handle).oper, (*handle).index); | |||
} | |||
return return_outputs; | |||
} | |||
} | |||
} |
@@ -13,7 +13,7 @@ namespace Tensorflow | |||
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | |||
/// https://www.tensorflow.org/guide/graphs | |||
/// </summary> | |||
public class Graph : IDisposable | |||
public partial class Graph : IDisposable | |||
{ | |||
private IntPtr _handle; | |||
private Dictionary<int, Operation> _nodes_by_id; | |||
@@ -211,18 +211,6 @@ namespace Tensorflow | |||
return _nodes_by_name.Values.Select(x => x).ToArray(); | |||
} | |||
public GraphDef ToGraphDef() | |||
{ | |||
var s = new Status(); | |||
var buffer = new Buffer(); | |||
c_api.TF_GraphToGraphDef(_handle, buffer, s); | |||
s.Check(); | |||
var def = GraphDef.Parser.ParseFrom(buffer); | |||
buffer.Dispose(); | |||
s.Dispose(); | |||
return def; | |||
} | |||
public void Dispose() | |||
{ | |||
c_api.TF_DeleteGraph(_handle); | |||
@@ -0,0 +1,35 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class ImportGraphDefOptions : IDisposable | |||
{ | |||
private IntPtr _handle; | |||
public int NumReturnOutputs => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); | |||
public ImportGraphDefOptions() | |||
{ | |||
_handle = c_api.TF_NewImportGraphDefOptions(); | |||
} | |||
public ImportGraphDefOptions(IntPtr handle) | |||
{ | |||
_handle = handle; | |||
} | |||
public void AddReturnOutput(string name, int index) | |||
{ | |||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | |||
} | |||
public void Dispose() | |||
{ | |||
c_api.TF_DeleteImportGraphDefOptions(_handle); | |||
} | |||
public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle; | |||
public static implicit operator ImportGraphDefOptions(IntPtr handle) => new ImportGraphDefOptions(handle); | |||
} | |||
} |
@@ -45,6 +45,24 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | |||
/// <summary> | |||
/// Import the graph serialized in `graph_def` into `graph`. | |||
/// Convenience function for when only return outputs are needed. | |||
/// | |||
/// `num_return_outputs` must be the number of return outputs added (i.e. the | |||
/// result of TF_ImportGraphDefOptionsNumReturnOutputs()). If | |||
/// `num_return_outputs` is non-zero, `return_outputs` must be of length | |||
/// `num_return_outputs`. Otherwise it can be null. | |||
/// </summary> | |||
/// <param name="graph">TF_Graph* graph</param> | |||
/// <param name="graph_def">const TF_Buffer*</param> | |||
/// <param name="options">const TF_ImportGraphDefOptions*</param> | |||
/// <param name="return_outputs">TF_Output*</param> | |||
/// <param name="num_return_outputs">int</param> | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, IntPtr options, TF_Output* return_outputs, int num_return_outputs, IntPtr status); | |||
/// <summary> | |||
/// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and | |||
/// a bad status on error. Otherwise, returns a populated | |||
@@ -357,10 +357,55 @@ namespace TensorFlowNET.UnitTest | |||
s.Dispose(); | |||
} | |||
/// <summary> | |||
/// Port from c_api_test.cc | |||
/// `TEST(CAPI, ImportGraphDef_WithReturnOutputs)` | |||
/// </summary> | |||
[TestMethod] | |||
public void c_api_ImportGraphDef_WithReturnOutputs() | |||
{ | |||
var s = new Status(); | |||
var graph = new Graph(); | |||
// Create a graph with two nodes: x and 3 | |||
var feed = c_test_util.Placeholder(graph, s); | |||
EXPECT_EQ(feed, graph.OperationByName("feed")); | |||
var scalar = c_test_util.ScalarConst(3, graph, s); | |||
EXPECT_EQ(scalar, graph.OperationByName("scalar")); | |||
var neg = c_test_util.Neg(scalar, graph, s); | |||
EXPECT_EQ(neg, graph.OperationByName("neg")); | |||
// Export to a GraphDef. | |||
var graph_def = graph.ToGraphDef(s); | |||
ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
// Import it in a fresh graph with return outputs. | |||
graph.Dispose(); | |||
graph = new Graph(); | |||
var opts = new ImportGraphDefOptions(); | |||
opts.AddReturnOutput("feed", 0); | |||
opts.AddReturnOutput("scalar", 0); | |||
EXPECT_EQ(2, opts.NumReturnOutputs); | |||
var return_outputs = graph.ImportGraphDefWithReturnOutputs(graph_def, opts, s); | |||
ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
scalar = graph.OperationByName("scalar"); | |||
feed = graph.OperationByName("feed"); | |||
neg = graph.OperationByName("neg"); | |||
ASSERT_TRUE(scalar != IntPtr.Zero); | |||
ASSERT_TRUE(feed != IntPtr.Zero); | |||
ASSERT_TRUE(neg != IntPtr.Zero); | |||
// Check return outputs | |||
EXPECT_EQ(feed, return_outputs[0].oper); | |||
EXPECT_EQ(0, return_outputs[0].index); | |||
EXPECT_EQ(scalar, return_outputs[1].oper); | |||
EXPECT_EQ(0, return_outputs[1].index); | |||
opts.Dispose(); | |||
graph_def.Dispose(); | |||
graph.Dispose(); | |||
s.Dispose(); | |||
} | |||
} | |||
} |