Browse Source

TEST(CAPI, ImportGraphDef))

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
642217491a
7 changed files with 401 additions and 83 deletions
  1. +28
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +120
    -1
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  3. +17
    -0
      src/TensorFlowNET.Core/Operations/Operation.cs
  4. +36
    -2
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  5. +25
    -0
      test/TensorFlowNET.UnitTest/CApiTest.cs
  6. +156
    -61
      test/TensorFlowNET.UnitTest/GraphTest.cs
  7. +19
    -19
      test/TensorFlowNET.UnitTest/TensorTest.cs

+ 28
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -171,6 +171,34 @@ namespace Tensorflow
return $"{name}_{_names_in_use[name_key]}";
}

public TF_Output[] ReturnOutputs(IntPtr results)
{
IntPtr return_output_handle = IntPtr.Zero;
int num_return_outputs = 0;
c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle);
TF_Output[] return_outputs = new TF_Output[num_return_outputs];
for (int i = 0; i < num_return_outputs; i++)
{
return_outputs[i] = Marshal.PtrToStructure<TF_Output>(return_output_handle + (Marshal.SizeOf<TF_Output>() * i));
}

return return_outputs;
}

public Operation[] ReturnOperations(IntPtr results)
{
IntPtr return_oper_handle = IntPtr.Zero;
int num_return_opers = 0;
c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_opers, ref return_oper_handle);
Operation[] return_opers = new Operation[num_return_opers];
for (int i = 0; i < num_return_opers; i++)
{
// return_opers[i] = Marshal.PtrToStructure<TF_Output>(return_oper_handle + (Marshal.SizeOf<TF_Output>() * i));
}

return return_opers;
}

public Operation[] get_operations()
{
return _nodes_by_name.Values.Select(x => x).ToArray();


+ 120
- 1
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -15,6 +15,13 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TF_DeleteGraph(IntPtr graph);

/// <summary>
///
/// </summary>
/// <param name="opts">TF_ImportGraphDefOptions*</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_DeleteImportGraphDefOptions(IntPtr opts);

[DllImport(TensorFlowLibName)]
public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status);

@@ -31,6 +38,29 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status);

/// <summary>
/// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and
/// a bad status on error. Otherwise, returns a populated
/// TF_ImportGraphDefResults instance. The returned instance must be deleted via
/// TF_DeleteImportGraphDefResults().
/// </summary>
/// <param name="graph">TF_Graph*</param>
/// <param name="graph_def">const TF_Buffer*</param>
/// <param name="options">const TF_ImportGraphDefOptions*</param>
/// <param name="status">TF_Status*</param>
/// <returns>TF_ImportGraphDefResults*</returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr status);

/// <summary>
/// Import the graph serialized in `graph_def` into `graph`.
/// </summary>
/// <param name="graph">TF_Graph*</param>
/// <param name="graph_def">TF_Buffer*</param>
/// <param name="options">TF_ImportGraphDefOptions*</param>
/// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr status);
/// <summary>
/// Iterate through the operations of a graph.
/// </summary>
@@ -80,7 +110,96 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, IntPtr status);

/// <summary>
/// Set any imported nodes with input `src_name:src_index` to have that input
/// replaced with `dst`. `src_name` refers to a node in the graph to be imported,
/// `dst` references a node already existing in the graph being imported into.
/// `src_name` is copied and has no lifetime requirements.
/// </summary>
/// <param name="opts">TF_ImportGraphDefOptions*</param>
/// <param name="src_name">const char*</param>
/// <param name="src_index">int</param>
/// <param name="dst">TF_Output</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsAddInputMapping(IntPtr opts, string src_name, int src_index, TF_Output dst);

/// <summary>
/// Add an operation in `graph_def` to be returned via the `return_opers` output
/// parameter of TF_GraphImportGraphDef(). `oper_name` is copied and has no
// lifetime requirements.
/// </summary>
/// <param name="opts">TF_ImportGraphDefOptions* opts</param>
/// <param name="oper_name">const char*</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsAddReturnOperation(IntPtr opts, string oper_name);

/// <summary>
/// Add an output in `graph_def` to be returned via the `return_outputs` output
/// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input
/// mapping, the corresponding existing tensor in `graph` will be returned.
/// `oper_name` is copied and has no lifetime requirements.
/// </summary>
/// <param name="opts">TF_ImportGraphDefOptions*</param>
/// <param name="oper_name">const char*</param>
/// <param name="index">int</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsAddReturnOutput(IntPtr opts, string oper_name, int index);

/// <summary>
/// Returns the number of return operations added via
/// TF_ImportGraphDefOptionsAddReturnOperation().
/// </summary>
/// <param name="opts"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern int TF_ImportGraphDefOptionsNumReturnOperations(IntPtr opts);

/// <summary>
/// Returns the number of return outputs added via
/// TF_ImportGraphDefOptionsAddReturnOutput().
/// </summary>
/// <param name="opts">const TF_ImportGraphDefOptions*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern int TF_ImportGraphDefOptionsNumReturnOutputs(IntPtr opts);

/// <summary>
/// Set the prefix to be prepended to the names of nodes in `graph_def` that will
/// be imported into `graph`. `prefix` is copied and has no lifetime
/// requirements.
/// </summary>
/// <param name="ops"></param>
[DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsSetPrefix(IntPtr ops, string prefix);

/// <summary>
/// Fetches the return operations requested via
/// TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched
/// operations is returned in `num_opers`. The array of return operations is
/// returned in `opers`. `*opers` is owned by and has the lifetime of `results`.
/// </summary>
/// <param name="results">TF_ImportGraphDefResults*</param>
/// <param name="num_opers">int*</param>
/// <param name="opers">TF_Operation***</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefResultsReturnOperations(IntPtr results, ref int num_opers, ref IntPtr opers);

/// <summary>
/// Fetches the return outputs requested via
/// TF_ImportGraphDefOptionsAddReturnOutput(). The number of fetched outputs is
/// returned in `num_outputs`. The array of return outputs is returned in
/// `outputs`. `*outputs` is owned by and has the lifetime of `results`.
/// </summary>
/// <param name="results">TF_ImportGraphDefResults* results</param>
/// <param name="num_outputs">int*</param>
/// <param name="outputs">TF_Output**</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefResultsReturnOutputs(IntPtr results, ref int num_outputs, ref IntPtr outputs);

[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_NewGraph();

[DllImport(TensorFlowLibName)]
public static unsafe extern IntPtr TF_NewGraph();
public static extern IntPtr TF_NewImportGraphDefOptions();
}
}

+ 17
- 0
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -41,8 +41,25 @@ namespace Tensorflow
}

public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle);

public Operation[] ControlInputs(int max_control_inputs)
{
var control_inputs = new Operation[NumControlInputs];
var control_input_handles = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlInputs);
c_api.TF_OperationGetControlInputs(_handle, control_input_handles, max_control_inputs);
return control_inputs;
}

public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);

public Operation[] ControlOutputs(int max_control_outputs)
{
var control_outputs = new Operation[NumControlOutputs];
var control_output_handles = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs);
c_api.TF_OperationGetControlInputs(_handle, control_output_handles, max_control_outputs);
return control_outputs;
}

private Tensor[] _outputs;
public Tensor[] outputs => _outputs;
public Tensor[] inputs;


+ 36
- 2
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -49,6 +49,35 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, IntPtr status);

/// <summary>
/// Get list of all control inputs to an operation. `control_inputs` must
/// point to an array of length `max_control_inputs` (ideally set to
/// TF_OperationNumControlInputs(oper)). Returns the number of control
/// inputs (should match TF_OperationNumControlInputs(oper)).
/// </summary>
/// <param name="oper">TF_Operation*</param>
/// <param name="control_inputs">TF_Operation**</param>
/// <param name="max_control_inputs"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern int TF_OperationGetControlInputs(IntPtr oper, IntPtr control_inputs, int max_control_inputs);

/// <summary>
/// Get the list of operations that have `*oper` as a control input.
/// `control_outputs` must point to an array of length at least
/// `max_control_outputs` (ideally set to
/// TF_OperationNumControlOutputs(oper)). Beware that a concurrent
/// modification of the graph can increase the number of control
/// outputs. Returns the number of control outputs (should match
/// TF_OperationNumControlOutputs(oper)).
/// </summary>
/// <param name="oper">TF_Operation*</param>
/// <param name="control_outputs">TF_Operation**</param>
/// <param name="max_control_outputs"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern int TF_OperationGetControlOutputs(IntPtr oper, IntPtr control_outputs, int max_control_outputs);

/// <summary>
/// TF_Output producer = TF_OperationInput(consumer);
/// There is an edge from producer.oper's output (given by
@@ -105,14 +134,19 @@ namespace Tensorflow

/// <summary>
/// Get list of all current consumers of a specific output of an
/// operation.
/// operation. `consumers` must point to an array of length at least
/// `max_consumers` (ideally set to
/// TF_OperationOutputNumConsumers(oper_out)). Beware that a concurrent
/// modification of the graph can increase the number of consumers of
/// an operation. Returns the number of output consumers (should match
/// TF_OperationOutputNumConsumers(oper_out)).
/// </summary>
/// <param name="oper_out"></param>
/// <param name="consumers"></param>
/// <param name="max_consumers"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input * consumers, int max_consumers);
public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers, int max_consumers);

[DllImport(TensorFlowLibName)]
public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out);


+ 25
- 0
test/TensorFlowNET.UnitTest/CApiTest.cs View File

@@ -0,0 +1,25 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Text;

namespace TensorFlowNET.UnitTest
{
public class CApiTest
{
public void EXPECT_EQ(object expected, object actual)
{
Assert.AreEqual(expected, actual);
}

public void ASSERT_EQ(object expected, object actual)
{
Assert.AreEqual(expected, actual);
}

public void ASSERT_TRUE(bool condition)
{
Assert.IsTrue(condition);
}
}
}

+ 156
- 61
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -1,13 +1,15 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using Tensorflow;
using Buffer = Tensorflow.Buffer;

namespace TensorFlowNET.UnitTest
{
[TestClass]
public class GraphTest
public class GraphTest : CApiTest
{
/// <summary>
/// Port from c_api_test.cc
@@ -21,74 +23,74 @@ namespace TensorFlowNET.UnitTest

// Make a placeholder operation.
var feed = c_test_util.Placeholder(graph, s);
Assert.AreEqual("feed", feed.Name);
Assert.AreEqual("Placeholder", feed.OpType);
Assert.AreEqual("", feed.Device);
Assert.AreEqual(1, feed.NumOutputs);
Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType(0));
Assert.AreEqual(1, feed.OutputListLength("output"));
Assert.AreEqual(0, feed.NumInputs);
Assert.AreEqual(0, feed.OutputNumConsumers(0));
Assert.AreEqual(0, feed.NumControlInputs);
Assert.AreEqual(0, feed.NumControlOutputs);
EXPECT_EQ("feed", feed.Name);
EXPECT_EQ("Placeholder", feed.OpType);
EXPECT_EQ("", feed.Device);
EXPECT_EQ(1, feed.NumOutputs);
EXPECT_EQ(TF_DataType.TF_INT32, feed.OutputType(0));
EXPECT_EQ(1, feed.OutputListLength("output"));
EXPECT_EQ(0, feed.NumInputs);
EXPECT_EQ(0, feed.OutputNumConsumers(0));
EXPECT_EQ(0, feed.NumControlInputs);
EXPECT_EQ(0, feed.NumControlOutputs);

AttrValue attr_value = null;
Assert.IsTrue(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s));
Assert.AreEqual(attr_value.Type, DataType.DtInt32);
ASSERT_TRUE(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s));
EXPECT_EQ(attr_value.Type, DataType.DtInt32);

// Test not found errors in TF_Operation*() query functions.
Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s));
Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code);
EXPECT_EQ(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s));
EXPECT_EQ(TF_Code.TF_INVALID_ARGUMENT, s.Code);
Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s));
Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message);
EXPECT_EQ("Operation 'feed' has no attr named 'missing'.", s.Message);

// Make a constant oper with the scalar "3".
var three = c_test_util.ScalarConst(3, graph, s);
Assert.AreEqual(TF_Code.TF_OK, s.Code);
EXPECT_EQ(TF_Code.TF_OK, s.Code);

// Add oper.
var add = c_test_util.Add(feed, three, graph, s);
Assert.AreEqual(TF_Code.TF_OK, s.Code);
EXPECT_EQ(TF_Code.TF_OK, s.Code);

// Test TF_Operation*() query functions.
Assert.AreEqual("add", add.Name);
Assert.AreEqual("AddN", add.OpType);
Assert.AreEqual("", add.Device);
Assert.AreEqual(1, add.NumOutputs);
Assert.AreEqual(TF_DataType.TF_INT32, add.OutputType(0));
Assert.AreEqual(1, add.OutputListLength("sum"));
Assert.AreEqual(TF_Code.TF_OK, s.Code);
Assert.AreEqual(2, add.InputListLength("inputs"));
Assert.AreEqual(TF_Code.TF_OK, s.Code);
Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(0));
Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(1));
EXPECT_EQ("add", add.Name);
EXPECT_EQ("AddN", add.OpType);
EXPECT_EQ("", add.Device);
EXPECT_EQ(1, add.NumOutputs);
EXPECT_EQ(TF_DataType.TF_INT32, add.OutputType(0));
EXPECT_EQ(1, add.OutputListLength("sum"));
EXPECT_EQ(TF_Code.TF_OK, s.Code);
EXPECT_EQ(2, add.InputListLength("inputs"));
EXPECT_EQ(TF_Code.TF_OK, s.Code);
EXPECT_EQ(TF_DataType.TF_INT32, add.InputType(0));
EXPECT_EQ(TF_DataType.TF_INT32, add.InputType(1));
var add_in_0 = add.Input(0);
Assert.AreEqual(feed, add_in_0.oper);
Assert.AreEqual(0, add_in_0.index);
EXPECT_EQ(feed, add_in_0.oper);
EXPECT_EQ(0, add_in_0.index);
var add_in_1 = add.Input(1);
Assert.AreEqual(three, add_in_1.oper);
Assert.AreEqual(0, add_in_1.index);
Assert.AreEqual(0, add.OutputNumConsumers(0));
Assert.AreEqual(0, add.NumControlInputs);
Assert.AreEqual(0, add.NumControlOutputs);
EXPECT_EQ(three, add_in_1.oper);
EXPECT_EQ(0, add_in_1.index);
EXPECT_EQ(0, add.OutputNumConsumers(0));
EXPECT_EQ(0, add.NumControlInputs);
EXPECT_EQ(0, add.NumControlOutputs);

Assert.IsTrue(c_test_util.GetAttrValue(add, "T", ref attr_value, s));
Assert.AreEqual(DataType.DtInt32, attr_value.Type);
Assert.IsTrue(c_test_util.GetAttrValue(add, "N", ref attr_value, s));
Assert.AreEqual(2, attr_value.I);
ASSERT_TRUE(c_test_util.GetAttrValue(add, "T", ref attr_value, s));
EXPECT_EQ(DataType.DtInt32, attr_value.Type);
ASSERT_TRUE(c_test_util.GetAttrValue(add, "N", ref attr_value, s));
EXPECT_EQ(2, attr_value.I);

// Placeholder oper now has a consumer.
Assert.AreEqual(1, feed.OutputNumConsumers(0));
EXPECT_EQ(1, feed.OutputNumConsumers(0));
TF_Input[] feed_port = feed.OutputConsumers(0, 1);
Assert.AreEqual(1, feed_port.Length);
Assert.AreEqual(add, feed_port[0].oper);
Assert.AreEqual(0, feed_port[0].index);
EXPECT_EQ(1, feed_port.Length);
EXPECT_EQ(add, feed_port[0].oper);
EXPECT_EQ(0, feed_port[0].index);

// The scalar const oper also has a consumer.
Assert.AreEqual(1, three.OutputNumConsumers(0));
EXPECT_EQ(1, three.OutputNumConsumers(0));
TF_Input[] three_port = three.OutputConsumers(0, 1);
Assert.AreEqual(add, three_port[0].oper);
Assert.AreEqual(1, three_port[0].index);
EXPECT_EQ(add, three_port[0].oper);
EXPECT_EQ(1, three_port[0].index);

// Serialize to GraphDef.
var graph_def = c_test_util.GetGraphDef(graph);
@@ -119,38 +121,38 @@ namespace TensorFlowNET.UnitTest
Assert.Fail($"Unexpected NodeDef: {n}");
}
}
Assert.IsTrue(found_placeholder);
Assert.IsTrue(found_scalar_const);
Assert.IsTrue(found_add);
ASSERT_TRUE(found_placeholder);
ASSERT_TRUE(found_scalar_const);
ASSERT_TRUE(found_add);

// Add another oper to the graph.
var neg = c_test_util.Neg(add, graph, s);
Assert.AreEqual(TF_Code.TF_OK, s.Code);
EXPECT_EQ(TF_Code.TF_OK, s.Code);

// Serialize to NodeDef.
var node_def = c_test_util.GetNodeDef(neg);

// Validate NodeDef is what we expect.
Assert.IsTrue(c_test_util.IsNeg(node_def, "add"));
ASSERT_TRUE(c_test_util.IsNeg(node_def, "add"));

// Serialize to GraphDef.
var graph_def2 = c_test_util.GetGraphDef(graph);

// Compare with first GraphDef + added NodeDef.
graph_def.Node.Add(node_def);
Assert.AreEqual(graph_def.ToString(), graph_def2.ToString());
EXPECT_EQ(graph_def.ToString(), graph_def2.ToString());

// Look up some nodes by name.
Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg");
Assert.AreEqual(neg, neg2);
EXPECT_EQ(neg, neg2);
var node_def2 = c_test_util.GetNodeDef(neg2);
Assert.AreEqual(node_def.ToString(), node_def2.ToString());
EXPECT_EQ(node_def.ToString(), node_def2.ToString());

Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed");
Assert.AreEqual(feed, feed2);
EXPECT_EQ(feed, feed2);
node_def = c_test_util.GetNodeDef(feed);
node_def2 = c_test_util.GetNodeDef(feed2);
Assert.AreEqual(node_def.ToString(), node_def2.ToString());
EXPECT_EQ(node_def.ToString(), node_def2.ToString());

// Test iterating through the nodes of a graph.
found_placeholder = false;
@@ -189,13 +191,106 @@ namespace TensorFlowNET.UnitTest
}
}

Assert.IsTrue(found_placeholder);
Assert.IsTrue(found_scalar_const);
Assert.IsTrue(found_add);
Assert.IsTrue(found_neg);
ASSERT_TRUE(found_placeholder);
ASSERT_TRUE(found_scalar_const);
ASSERT_TRUE(found_add);
ASSERT_TRUE(found_neg);

graph.Dispose();
s.Dispose();
}

/// <summary>
/// Port from c_api_test.cc
/// `TEST(CAPI, ImportGraphDef)`
/// </summary>
[TestMethod]
public void c_api_ImportGraphDef()
{
var s = new Status();
var graph = new Graph();

// Create a simple graph.
c_test_util.Placeholder(graph, s);
var oper = c_test_util.ScalarConst(3, graph, s);
c_test_util.Neg(oper, graph, s);

// Export to a GraphDef.
var graph_def = new Buffer();
c_api.TF_GraphToGraphDef(graph, graph_def, s);
EXPECT_EQ(TF_Code.TF_OK, s.Code);

// Import it, with a prefix, in a fresh graph.
graph.Dispose();
graph = new Graph();
var opts = c_api.TF_NewImportGraphDefOptions();
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s);
EXPECT_EQ(TF_Code.TF_OK, s.Code);

Operation scalar = c_api.TF_GraphOperationByName(graph, "imported/scalar");
Operation feed = c_api.TF_GraphOperationByName(graph, "imported/feed");
Operation neg = c_api.TF_GraphOperationByName(graph, "imported/neg");

// Test basic structure of the imported graph.
EXPECT_EQ(0, scalar.NumInputs);
EXPECT_EQ(0, feed.NumInputs);
EXPECT_EQ(1, neg.NumInputs);

var neg_input = neg.Input(0);
EXPECT_EQ(scalar, neg_input.oper);
EXPECT_EQ(0, neg_input.index);

// Test that we can't see control edges involving the source and sink nodes.
EXPECT_EQ(0, scalar.NumControlInputs);
EXPECT_EQ(0, scalar.ControlInputs(100).Length);
EXPECT_EQ(0, scalar.NumControlOutputs);
EXPECT_EQ(0, scalar.ControlOutputs(100).Length);

EXPECT_EQ(0, feed.NumControlInputs);
EXPECT_EQ(0, feed.ControlInputs(100).Length);
EXPECT_EQ(0, feed.NumControlOutputs);
EXPECT_EQ(0, feed.ControlOutputs(100).Length);

EXPECT_EQ(0, neg.NumControlInputs);
EXPECT_EQ(0, neg.ControlInputs(100).Length);
EXPECT_EQ(0, neg.NumControlOutputs);
EXPECT_EQ(0, neg.ControlOutputs(100).Length);

// Import it again, with an input mapping, return outputs, and a return
// operation, into the same graph.
c_api.TF_DeleteImportGraphDefOptions(opts);
opts = c_api.TF_NewImportGraphDefOptions();
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2");
c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0));
c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts));
c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar");
EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts));
var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
EXPECT_EQ(TF_Code.TF_OK, s.Code);

Operation scalar2 = c_api.TF_GraphOperationByName(graph, "imported2/scalar");
Operation feed2 = c_api.TF_GraphOperationByName(graph, "imported2/feed");
Operation neg2 = c_api.TF_GraphOperationByName(graph, "imported2/neg");

// Check input mapping
neg_input = neg.Input(0);
EXPECT_EQ(scalar, neg_input.oper);
EXPECT_EQ(0, neg_input.index);

// Check return outputs
var return_outputs = graph.ReturnOutputs(results);
ASSERT_EQ(2, return_outputs.Length);
EXPECT_EQ(feed2, return_outputs[0].oper);
EXPECT_EQ(0, return_outputs[0].index);
EXPECT_EQ(scalar, return_outputs[1].oper); // remapped
EXPECT_EQ(0, return_outputs[1].index);

// Check return operation
var num_return_opers = graph.ReturnOperations(results);
ASSERT_EQ(1, num_return_opers);
}
}
}

+ 19
- 19
test/TensorFlowNET.UnitTest/TensorTest.cs View File

@@ -10,7 +10,7 @@ using Tensorflow;
namespace TensorFlowNET.UnitTest
{
[TestClass]
public class TensorTest
public class TensorTest : CApiTest
{
/// <summary>
/// Port from c_api_test.cc
@@ -22,10 +22,10 @@ namespace TensorFlowNET.UnitTest
ulong num_bytes = 6 * sizeof(float);
long[] dims = { 2, 3 };
Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes);
Assert.AreEqual(TF_DataType.TF_FLOAT, t.dtype);
Assert.AreEqual(2, t.NDims);
EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype);
EXPECT_EQ(2, t.NDims);
Assert.IsTrue(Enumerable.SequenceEqual(dims, t.shape));
Assert.AreEqual(num_bytes, t.bytesize);
EXPECT_EQ(num_bytes, t.bytesize);
t.Dispose();
}

@@ -41,11 +41,11 @@ namespace TensorFlowNET.UnitTest
var tensor = new Tensor(nd);
var array = tensor.Data<float>();

Assert.AreEqual(tensor.dtype, TF_DataType.TF_FLOAT);
Assert.AreEqual(tensor.rank, nd.ndim);
Assert.AreEqual(tensor.shape[0], nd.shape[0]);
Assert.AreEqual(tensor.shape[1], nd.shape[1]);
Assert.AreEqual(tensor.bytesize, (uint)nd.size * sizeof(float));
EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT);
EXPECT_EQ(tensor.rank, nd.ndim);
EXPECT_EQ(tensor.shape[0], nd.shape[0]);
EXPECT_EQ(tensor.shape[1], nd.shape[1]);
EXPECT_EQ(tensor.bytesize, (uint)nd.size * sizeof(float));
Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array));
}

@@ -66,20 +66,20 @@ namespace TensorFlowNET.UnitTest
int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);

Assert.IsTrue(s.Code == TF_Code.TF_OK);
Assert.AreEqual(-1, num_dims);
EXPECT_EQ(-1, num_dims);

// Set the shape to be unknown, expect no change.
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
Assert.AreEqual(-1, num_dims);
EXPECT_EQ(-1, num_dims);

// Set the shape to be 2 x Unknown
long[] dims = { 2, -1 };
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
Assert.AreEqual(2, num_dims);
EXPECT_EQ(2, num_dims);

// Get the dimension vector appropriately.
var returned_dims = new long[dims.Length];
@@ -103,9 +103,9 @@ namespace TensorFlowNET.UnitTest
Assert.IsTrue(s.Code == TF_Code.TF_OK);
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
Assert.AreEqual(2, num_dims);
Assert.AreEqual(2, returned_dims[0]);
Assert.AreEqual(3, returned_dims[1]);
EXPECT_EQ(2, num_dims);
EXPECT_EQ(2, returned_dims[0]);
EXPECT_EQ(3, returned_dims[1]);

// Try to set 'unknown' with same rank on the shape and see that
// it doesn't change.
@@ -115,9 +115,9 @@ namespace TensorFlowNET.UnitTest
Assert.IsTrue(s.Code == TF_Code.TF_OK);
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
Assert.AreEqual(2, num_dims);
Assert.AreEqual(2, returned_dims[0]);
Assert.AreEqual(3, returned_dims[1]);
EXPECT_EQ(2, num_dims);
EXPECT_EQ(2, returned_dims[0]);
EXPECT_EQ(3, returned_dims[1]);

// Try to fetch a shape with the wrong num_dims
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s);
@@ -135,7 +135,7 @@ namespace TensorFlowNET.UnitTest

num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
Assert.AreEqual(0, num_dims);
EXPECT_EQ(0, num_dims);
c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s);
//Assert.IsTrue(s.Code == TF_Code.TF_OK);



Loading…
Cancel
Save