@@ -0,0 +1,21 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Runtime.InteropServices; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class Graph | |||||
{ | |||||
public Buffer ToGraphDef(Status s) | |||||
{ | |||||
var buffer = new Buffer(); | |||||
c_api.TF_GraphToGraphDef(_handle, buffer, s); | |||||
s.Check(); | |||||
// var def = GraphDef.Parser.ParseFrom(buffer); | |||||
// buffer.Dispose(); | |||||
return buffer; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,28 @@ | |||||
using Google.Protobuf; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Runtime.InteropServices; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class Graph | |||||
{ | |||||
public unsafe TF_Output[] ImportGraphDefWithReturnOutputs(Buffer graph_def, ImportGraphDefOptions opts, Status s) | |||||
{ | |||||
var num_return_outputs = opts.NumReturnOutputs; | |||||
var return_outputs = new TF_Output[num_return_outputs]; | |||||
TF_Output* return_output_handle = (TF_Output*)Marshal.AllocHGlobal(Marshal.SizeOf<TF_Output>() * 2); | |||||
c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); | |||||
for (int i = 0; i < num_return_outputs; i++) | |||||
{ | |||||
var handle = return_output_handle + i * Marshal.SizeOf<TF_Output>(); | |||||
return_outputs[i] = new TF_Output((*handle).oper, (*handle).index); | |||||
} | |||||
return return_outputs; | |||||
} | |||||
} | |||||
} |
@@ -13,7 +13,7 @@ namespace Tensorflow | |||||
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | ||||
/// https://www.tensorflow.org/guide/graphs | /// https://www.tensorflow.org/guide/graphs | ||||
/// </summary> | /// </summary> | ||||
public class Graph : IDisposable | |||||
public partial class Graph : IDisposable | |||||
{ | { | ||||
private IntPtr _handle; | private IntPtr _handle; | ||||
private Dictionary<int, Operation> _nodes_by_id; | private Dictionary<int, Operation> _nodes_by_id; | ||||
@@ -211,18 +211,6 @@ namespace Tensorflow | |||||
return _nodes_by_name.Values.Select(x => x).ToArray(); | return _nodes_by_name.Values.Select(x => x).ToArray(); | ||||
} | } | ||||
public GraphDef ToGraphDef() | |||||
{ | |||||
var s = new Status(); | |||||
var buffer = new Buffer(); | |||||
c_api.TF_GraphToGraphDef(_handle, buffer, s); | |||||
s.Check(); | |||||
var def = GraphDef.Parser.ParseFrom(buffer); | |||||
buffer.Dispose(); | |||||
s.Dispose(); | |||||
return def; | |||||
} | |||||
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
c_api.TF_DeleteGraph(_handle); | c_api.TF_DeleteGraph(_handle); | ||||
@@ -0,0 +1,35 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class ImportGraphDefOptions : IDisposable | |||||
{ | |||||
private IntPtr _handle; | |||||
public int NumReturnOutputs => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); | |||||
public ImportGraphDefOptions() | |||||
{ | |||||
_handle = c_api.TF_NewImportGraphDefOptions(); | |||||
} | |||||
public ImportGraphDefOptions(IntPtr handle) | |||||
{ | |||||
_handle = handle; | |||||
} | |||||
public void AddReturnOutput(string name, int index) | |||||
{ | |||||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | |||||
} | |||||
public void Dispose() | |||||
{ | |||||
c_api.TF_DeleteImportGraphDefOptions(_handle); | |||||
} | |||||
public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle; | |||||
public static implicit operator ImportGraphDefOptions(IntPtr handle) => new ImportGraphDefOptions(handle); | |||||
} | |||||
} |
@@ -45,6 +45,24 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | ||||
/// <summary> | |||||
/// Import the graph serialized in `graph_def` into `graph`. | |||||
/// Convenience function for when only return outputs are needed. | |||||
/// | |||||
/// `num_return_outputs` must be the number of return outputs added (i.e. the | |||||
/// result of TF_ImportGraphDefOptionsNumReturnOutputs()). If | |||||
/// `num_return_outputs` is non-zero, `return_outputs` must be of length | |||||
/// `num_return_outputs`. Otherwise it can be null. | |||||
/// </summary> | |||||
/// <param name="graph">TF_Graph* graph</param> | |||||
/// <param name="graph_def">const TF_Buffer*</param> | |||||
/// <param name="options">const TF_ImportGraphDefOptions*</param> | |||||
/// <param name="return_outputs">TF_Output*</param> | |||||
/// <param name="num_return_outputs">int</param> | |||||
/// <param name="status">TF_Status*</param> | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, IntPtr options, TF_Output* return_outputs, int num_return_outputs, IntPtr status); | |||||
/// <summary> | /// <summary> | ||||
/// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and | /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and | ||||
/// a bad status on error. Otherwise, returns a populated | /// a bad status on error. Otherwise, returns a populated | ||||
@@ -357,10 +357,55 @@ namespace TensorFlowNET.UnitTest | |||||
s.Dispose(); | s.Dispose(); | ||||
} | } | ||||
/// <summary> | |||||
/// Port from c_api_test.cc | |||||
/// `TEST(CAPI, ImportGraphDef_WithReturnOutputs)` | |||||
/// </summary> | |||||
[TestMethod] | [TestMethod] | ||||
public void c_api_ImportGraphDef_WithReturnOutputs() | public void c_api_ImportGraphDef_WithReturnOutputs() | ||||
{ | { | ||||
var s = new Status(); | |||||
var graph = new Graph(); | |||||
// Create a graph with two nodes: x and 3 | |||||
var feed = c_test_util.Placeholder(graph, s); | |||||
EXPECT_EQ(feed, graph.OperationByName("feed")); | |||||
var scalar = c_test_util.ScalarConst(3, graph, s); | |||||
EXPECT_EQ(scalar, graph.OperationByName("scalar")); | |||||
var neg = c_test_util.Neg(scalar, graph, s); | |||||
EXPECT_EQ(neg, graph.OperationByName("neg")); | |||||
// Export to a GraphDef. | |||||
var graph_def = graph.ToGraphDef(s); | |||||
ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||||
// Import it in a fresh graph with return outputs. | |||||
graph.Dispose(); | |||||
graph = new Graph(); | |||||
var opts = new ImportGraphDefOptions(); | |||||
opts.AddReturnOutput("feed", 0); | |||||
opts.AddReturnOutput("scalar", 0); | |||||
EXPECT_EQ(2, opts.NumReturnOutputs); | |||||
var return_outputs = graph.ImportGraphDefWithReturnOutputs(graph_def, opts, s); | |||||
ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||||
scalar = graph.OperationByName("scalar"); | |||||
feed = graph.OperationByName("feed"); | |||||
neg = graph.OperationByName("neg"); | |||||
ASSERT_TRUE(scalar != IntPtr.Zero); | |||||
ASSERT_TRUE(feed != IntPtr.Zero); | |||||
ASSERT_TRUE(neg != IntPtr.Zero); | |||||
// Check return outputs | |||||
EXPECT_EQ(feed, return_outputs[0].oper); | |||||
EXPECT_EQ(0, return_outputs[0].index); | |||||
EXPECT_EQ(scalar, return_outputs[1].oper); | |||||
EXPECT_EQ(0, return_outputs[1].index); | |||||
opts.Dispose(); | |||||
graph_def.Dispose(); | |||||
graph.Dispose(); | |||||
s.Dispose(); | |||||
} | } | ||||
} | } | ||||
} | } |