Browse Source

Finished TEST(CAPI, ImportGraphDef_WithReturnOutputs).

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
a873bbc69b
6 changed files with 148 additions and 13 deletions
  1. +21
    -0
      src/TensorFlowNET.Core/Graphs/Graph.Export.cs
  2. +28
    -0
      src/TensorFlowNET.Core/Graphs/Graph.Import.cs
  3. +1
    -13
      src/TensorFlowNET.Core/Graphs/Graph.cs
  4. +35
    -0
      src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs
  5. +18
    -0
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  6. +45
    -0
      test/TensorFlowNET.UnitTest/GraphTest.cs

+ 21
- 0
src/TensorFlowNET.Core/Graphs/Graph.Export.cs View File

@@ -0,0 +1,21 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace Tensorflow
{
public partial class Graph
{
public Buffer ToGraphDef(Status s)
{
var buffer = new Buffer();
c_api.TF_GraphToGraphDef(_handle, buffer, s);
s.Check();
// var def = GraphDef.Parser.ParseFrom(buffer);
// buffer.Dispose();

return buffer;
}
}
}

+ 28
- 0
src/TensorFlowNET.Core/Graphs/Graph.Import.cs View File

@@ -0,0 +1,28 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace Tensorflow
{
public partial class Graph
{
public unsafe TF_Output[] ImportGraphDefWithReturnOutputs(Buffer graph_def, ImportGraphDefOptions opts, Status s)
{
var num_return_outputs = opts.NumReturnOutputs;
var return_outputs = new TF_Output[num_return_outputs];
TF_Output* return_output_handle = (TF_Output*)Marshal.AllocHGlobal(Marshal.SizeOf<TF_Output>() * 2);

c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s);
for (int i = 0; i < num_return_outputs; i++)
{
var handle = return_output_handle + i * Marshal.SizeOf<TF_Output>();

return_outputs[i] = new TF_Output((*handle).oper, (*handle).index);
}

return return_outputs;
}
}
}

+ 1
- 13
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -13,7 +13,7 @@ namespace Tensorflow
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
/// https://www.tensorflow.org/guide/graphs
/// </summary>
public class Graph : IDisposable
public partial class Graph : IDisposable
{
private IntPtr _handle;
private Dictionary<int, Operation> _nodes_by_id;
@@ -211,18 +211,6 @@ namespace Tensorflow
return _nodes_by_name.Values.Select(x => x).ToArray();
}

public GraphDef ToGraphDef()
{
var s = new Status();
var buffer = new Buffer();
c_api.TF_GraphToGraphDef(_handle, buffer, s);
s.Check();
var def = GraphDef.Parser.ParseFrom(buffer);
buffer.Dispose();
s.Dispose();
return def;
}

public void Dispose()
{
c_api.TF_DeleteGraph(_handle);


+ 35
- 0
src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs View File

@@ -0,0 +1,35 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class ImportGraphDefOptions : IDisposable
{
private IntPtr _handle;
public int NumReturnOutputs => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle);

public ImportGraphDefOptions()
{
_handle = c_api.TF_NewImportGraphDefOptions();
}

public ImportGraphDefOptions(IntPtr handle)
{
_handle = handle;
}

public void AddReturnOutput(string name, int index)
{
c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index);
}

public void Dispose()
{
c_api.TF_DeleteImportGraphDefOptions(_handle);
}

public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle;
public static implicit operator ImportGraphDefOptions(IntPtr handle) => new ImportGraphDefOptions(handle);
}
}

+ 18
- 0
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -45,6 +45,24 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status);

/// <summary>
/// Import the graph serialized in `graph_def` into `graph`.
/// Convenience function for when only return outputs are needed.
///
/// `num_return_outputs` must be the number of return outputs added (i.e. the
/// result of TF_ImportGraphDefOptionsNumReturnOutputs()). If
/// `num_return_outputs` is non-zero, `return_outputs` must be of length
/// `num_return_outputs`. Otherwise it can be null.
/// </summary>
/// <param name="graph">TF_Graph* graph</param>
/// <param name="graph_def">const TF_Buffer*</param>
/// <param name="options">const TF_ImportGraphDefOptions*</param>
/// <param name="return_outputs">TF_Output*</param>
/// <param name="num_return_outputs">int</param>
/// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, IntPtr options, TF_Output* return_outputs, int num_return_outputs, IntPtr status);

/// <summary>
/// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and
/// a bad status on error. Otherwise, returns a populated


+ 45
- 0
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -357,10 +357,55 @@ namespace TensorFlowNET.UnitTest
s.Dispose();
}

/// <summary>
/// Port from c_api_test.cc
/// `TEST(CAPI, ImportGraphDef_WithReturnOutputs)`
/// </summary>
[TestMethod]
public void c_api_ImportGraphDef_WithReturnOutputs()
{
var s = new Status();
var graph = new Graph();

// Create a graph with two nodes: x and 3
var feed = c_test_util.Placeholder(graph, s);
EXPECT_EQ(feed, graph.OperationByName("feed"));
var scalar = c_test_util.ScalarConst(3, graph, s);
EXPECT_EQ(scalar, graph.OperationByName("scalar"));
var neg = c_test_util.Neg(scalar, graph, s);
EXPECT_EQ(neg, graph.OperationByName("neg"));

// Export to a GraphDef.
var graph_def = graph.ToGraphDef(s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);

// Import it in a fresh graph with return outputs.
graph.Dispose();
graph = new Graph();
var opts = new ImportGraphDefOptions();
opts.AddReturnOutput("feed", 0);
opts.AddReturnOutput("scalar", 0);
EXPECT_EQ(2, opts.NumReturnOutputs);
var return_outputs = graph.ImportGraphDefWithReturnOutputs(graph_def, opts, s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);

scalar = graph.OperationByName("scalar");
feed = graph.OperationByName("feed");
neg = graph.OperationByName("neg");
ASSERT_TRUE(scalar != IntPtr.Zero);
ASSERT_TRUE(feed != IntPtr.Zero);
ASSERT_TRUE(neg != IntPtr.Zero);

// Check return outputs
EXPECT_EQ(feed, return_outputs[0].oper);
EXPECT_EQ(0, return_outputs[0].index);
EXPECT_EQ(scalar, return_outputs[1].oper);
EXPECT_EQ(0, return_outputs[1].index);

opts.Dispose();
graph_def.Dispose();
graph.Dispose();
s.Dispose();
}
}
}

Loading…
Cancel
Save