@@ -171,6 +171,34 @@ namespace Tensorflow | |||
return $"{name}_{_names_in_use[name_key]}"; | |||
} | |||
public TF_Output[] ReturnOutputs(IntPtr results) | |||
{ | |||
IntPtr return_output_handle = IntPtr.Zero; | |||
int num_return_outputs = 0; | |||
c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle); | |||
TF_Output[] return_outputs = new TF_Output[num_return_outputs]; | |||
for (int i = 0; i < num_return_outputs; i++) | |||
{ | |||
return_outputs[i] = Marshal.PtrToStructure<TF_Output>(return_output_handle + (Marshal.SizeOf<TF_Output>() * i)); | |||
} | |||
return return_outputs; | |||
} | |||
public Operation[] ReturnOperations(IntPtr results) | |||
{ | |||
IntPtr return_oper_handle = IntPtr.Zero; | |||
int num_return_opers = 0; | |||
c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_opers, ref return_oper_handle); | |||
Operation[] return_opers = new Operation[num_return_opers]; | |||
for (int i = 0; i < num_return_opers; i++) | |||
{ | |||
// return_opers[i] = Marshal.PtrToStructure<TF_Output>(return_oper_handle + (Marshal.SizeOf<TF_Output>() * i)); | |||
} | |||
return return_opers; | |||
} | |||
public Operation[] get_operations() | |||
{ | |||
return _nodes_by_name.Values.Select(x => x).ToArray(); | |||
@@ -15,6 +15,13 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_DeleteGraph(IntPtr graph); | |||
/// <summary> | |||
/// | |||
/// </summary> | |||
/// <param name="opts">TF_ImportGraphDefOptions*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_DeleteImportGraphDefOptions(IntPtr opts); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); | |||
@@ -31,6 +38,29 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | |||
/// <summary> | |||
/// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and | |||
/// a bad status on error. Otherwise, returns a populated | |||
/// TF_ImportGraphDefResults instance. The returned instance must be deleted via | |||
/// TF_DeleteImportGraphDefResults(). | |||
/// </summary> | |||
/// <param name="graph">TF_Graph*</param> | |||
/// <param name="graph_def">const TF_Buffer*</param> | |||
/// <param name="options">const TF_ImportGraphDefOptions*</param> | |||
/// <param name="status">TF_Status*</param> | |||
/// <returns>TF_ImportGraphDefResults*</returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr status); | |||
/// <summary> | |||
/// Import the graph serialized in `graph_def` into `graph`. | |||
/// </summary> | |||
/// <param name="graph">TF_Graph*</param> | |||
/// <param name="graph_def">TF_Buffer*</param> | |||
/// <param name="options">TF_ImportGraphDefOptions*</param> | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr status); | |||
/// <summary> | |||
/// Iterate through the operations of a graph. | |||
/// </summary> | |||
@@ -80,7 +110,96 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, IntPtr status); | |||
/// <summary> | |||
/// Set any imported nodes with input `src_name:src_index` to have that input | |||
/// replaced with `dst`. `src_name` refers to a node in the graph to be imported, | |||
/// `dst` references a node already existing in the graph being imported into. | |||
/// `src_name` is copied and has no lifetime requirements. | |||
/// </summary> | |||
/// <param name="opts">TF_ImportGraphDefOptions*</param> | |||
/// <param name="src_name">const char*</param> | |||
/// <param name="src_index">int</param> | |||
/// <param name="dst">TF_Output</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_ImportGraphDefOptionsAddInputMapping(IntPtr opts, string src_name, int src_index, TF_Output dst); | |||
/// <summary> | |||
/// Add an operation in `graph_def` to be returned via the `return_opers` output | |||
/// parameter of TF_GraphImportGraphDef(). `oper_name` is copied and has no | |||
// lifetime requirements. | |||
/// </summary> | |||
/// <param name="opts">TF_ImportGraphDefOptions* opts</param> | |||
/// <param name="oper_name">const char*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_ImportGraphDefOptionsAddReturnOperation(IntPtr opts, string oper_name); | |||
/// <summary> | |||
/// Add an output in `graph_def` to be returned via the `return_outputs` output | |||
/// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input | |||
/// mapping, the corresponding existing tensor in `graph` will be returned. | |||
/// `oper_name` is copied and has no lifetime requirements. | |||
/// </summary> | |||
/// <param name="opts">TF_ImportGraphDefOptions*</param> | |||
/// <param name="oper_name">const char*</param> | |||
/// <param name="index">int</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_ImportGraphDefOptionsAddReturnOutput(IntPtr opts, string oper_name, int index); | |||
/// <summary> | |||
/// Returns the number of return operations added via | |||
/// TF_ImportGraphDefOptionsAddReturnOperation(). | |||
/// </summary> | |||
/// <param name="opts"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TF_ImportGraphDefOptionsNumReturnOperations(IntPtr opts); | |||
/// <summary> | |||
/// Returns the number of return outputs added via | |||
/// TF_ImportGraphDefOptionsAddReturnOutput(). | |||
/// </summary> | |||
/// <param name="opts">const TF_ImportGraphDefOptions*</param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TF_ImportGraphDefOptionsNumReturnOutputs(IntPtr opts); | |||
/// <summary> | |||
/// Set the prefix to be prepended to the names of nodes in `graph_def` that will | |||
/// be imported into `graph`. `prefix` is copied and has no lifetime | |||
/// requirements. | |||
/// </summary> | |||
/// <param name="ops"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_ImportGraphDefOptionsSetPrefix(IntPtr ops, string prefix); | |||
/// <summary> | |||
/// Fetches the return operations requested via | |||
/// TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched | |||
/// operations is returned in `num_opers`. The array of return operations is | |||
/// returned in `opers`. `*opers` is owned by and has the lifetime of `results`. | |||
/// </summary> | |||
/// <param name="results">TF_ImportGraphDefResults*</param> | |||
/// <param name="num_opers">int*</param> | |||
/// <param name="opers">TF_Operation***</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_ImportGraphDefResultsReturnOperations(IntPtr results, ref int num_opers, ref IntPtr opers); | |||
/// <summary> | |||
/// Fetches the return outputs requested via | |||
/// TF_ImportGraphDefOptionsAddReturnOutput(). The number of fetched outputs is | |||
/// returned in `num_outputs`. The array of return outputs is returned in | |||
/// `outputs`. `*outputs` is owned by and has the lifetime of `results`. | |||
/// </summary> | |||
/// <param name="results">TF_ImportGraphDefResults* results</param> | |||
/// <param name="num_outputs">int*</param> | |||
/// <param name="outputs">TF_Output**</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_ImportGraphDefResultsReturnOutputs(IntPtr results, ref int num_outputs, ref IntPtr outputs); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_NewGraph(); | |||
[DllImport(TensorFlowLibName)] | |||
public static unsafe extern IntPtr TF_NewGraph(); | |||
public static extern IntPtr TF_NewImportGraphDefOptions(); | |||
} | |||
} |
@@ -41,8 +41,25 @@ namespace Tensorflow | |||
} | |||
public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); | |||
public Operation[] ControlInputs(int max_control_inputs) | |||
{ | |||
var control_inputs = new Operation[NumControlInputs]; | |||
var control_input_handles = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlInputs); | |||
c_api.TF_OperationGetControlInputs(_handle, control_input_handles, max_control_inputs); | |||
return control_inputs; | |||
} | |||
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | |||
public Operation[] ControlOutputs(int max_control_outputs) | |||
{ | |||
var control_outputs = new Operation[NumControlOutputs]; | |||
var control_output_handles = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs); | |||
c_api.TF_OperationGetControlInputs(_handle, control_output_handles, max_control_outputs); | |||
return control_outputs; | |||
} | |||
private Tensor[] _outputs; | |||
public Tensor[] outputs => _outputs; | |||
public Tensor[] inputs; | |||
@@ -49,6 +49,35 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, IntPtr status); | |||
/// <summary> | |||
/// Get list of all control inputs to an operation. `control_inputs` must | |||
/// point to an array of length `max_control_inputs` (ideally set to | |||
/// TF_OperationNumControlInputs(oper)). Returns the number of control | |||
/// inputs (should match TF_OperationNumControlInputs(oper)). | |||
/// </summary> | |||
/// <param name="oper">TF_Operation*</param> | |||
/// <param name="control_inputs">TF_Operation**</param> | |||
/// <param name="max_control_inputs"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TF_OperationGetControlInputs(IntPtr oper, IntPtr control_inputs, int max_control_inputs); | |||
/// <summary> | |||
/// Get the list of operations that have `*oper` as a control input. | |||
/// `control_outputs` must point to an array of length at least | |||
/// `max_control_outputs` (ideally set to | |||
/// TF_OperationNumControlOutputs(oper)). Beware that a concurrent | |||
/// modification of the graph can increase the number of control | |||
/// outputs. Returns the number of control outputs (should match | |||
/// TF_OperationNumControlOutputs(oper)). | |||
/// </summary> | |||
/// <param name="oper">TF_Operation*</param> | |||
/// <param name="control_outputs">TF_Operation**</param> | |||
/// <param name="max_control_outputs"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TF_OperationGetControlOutputs(IntPtr oper, IntPtr control_outputs, int max_control_outputs); | |||
/// <summary> | |||
/// TF_Output producer = TF_OperationInput(consumer); | |||
/// There is an edge from producer.oper's output (given by | |||
@@ -105,14 +134,19 @@ namespace Tensorflow | |||
/// <summary> | |||
/// Get list of all current consumers of a specific output of an | |||
/// operation. | |||
/// operation. `consumers` must point to an array of length at least | |||
/// `max_consumers` (ideally set to | |||
/// TF_OperationOutputNumConsumers(oper_out)). Beware that a concurrent | |||
/// modification of the graph can increase the number of consumers of | |||
/// an operation. Returns the number of output consumers (should match | |||
/// TF_OperationOutputNumConsumers(oper_out)). | |||
/// </summary> | |||
/// <param name="oper_out"></param> | |||
/// <param name="consumers"></param> | |||
/// <param name="max_consumers"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input * consumers, int max_consumers); | |||
public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers, int max_consumers); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); | |||
@@ -0,0 +1,25 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace TensorFlowNET.UnitTest | |||
{ | |||
public class CApiTest | |||
{ | |||
public void EXPECT_EQ(object expected, object actual) | |||
{ | |||
Assert.AreEqual(expected, actual); | |||
} | |||
public void ASSERT_EQ(object expected, object actual) | |||
{ | |||
Assert.AreEqual(expected, actual); | |||
} | |||
public void ASSERT_TRUE(bool condition) | |||
{ | |||
Assert.IsTrue(condition); | |||
} | |||
} | |||
} |
@@ -1,13 +1,15 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Runtime.InteropServices; | |||
using System.Text; | |||
using Tensorflow; | |||
using Buffer = Tensorflow.Buffer; | |||
namespace TensorFlowNET.UnitTest | |||
{ | |||
[TestClass] | |||
public class GraphTest | |||
public class GraphTest : CApiTest | |||
{ | |||
/// <summary> | |||
/// Port from c_api_test.cc | |||
@@ -21,74 +23,74 @@ namespace TensorFlowNET.UnitTest | |||
// Make a placeholder operation. | |||
var feed = c_test_util.Placeholder(graph, s); | |||
Assert.AreEqual("feed", feed.Name); | |||
Assert.AreEqual("Placeholder", feed.OpType); | |||
Assert.AreEqual("", feed.Device); | |||
Assert.AreEqual(1, feed.NumOutputs); | |||
Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType(0)); | |||
Assert.AreEqual(1, feed.OutputListLength("output")); | |||
Assert.AreEqual(0, feed.NumInputs); | |||
Assert.AreEqual(0, feed.OutputNumConsumers(0)); | |||
Assert.AreEqual(0, feed.NumControlInputs); | |||
Assert.AreEqual(0, feed.NumControlOutputs); | |||
EXPECT_EQ("feed", feed.Name); | |||
EXPECT_EQ("Placeholder", feed.OpType); | |||
EXPECT_EQ("", feed.Device); | |||
EXPECT_EQ(1, feed.NumOutputs); | |||
EXPECT_EQ(TF_DataType.TF_INT32, feed.OutputType(0)); | |||
EXPECT_EQ(1, feed.OutputListLength("output")); | |||
EXPECT_EQ(0, feed.NumInputs); | |||
EXPECT_EQ(0, feed.OutputNumConsumers(0)); | |||
EXPECT_EQ(0, feed.NumControlInputs); | |||
EXPECT_EQ(0, feed.NumControlOutputs); | |||
AttrValue attr_value = null; | |||
Assert.IsTrue(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s)); | |||
Assert.AreEqual(attr_value.Type, DataType.DtInt32); | |||
ASSERT_TRUE(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s)); | |||
EXPECT_EQ(attr_value.Type, DataType.DtInt32); | |||
// Test not found errors in TF_Operation*() query functions. | |||
Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); | |||
Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code); | |||
EXPECT_EQ(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); | |||
EXPECT_EQ(TF_Code.TF_INVALID_ARGUMENT, s.Code); | |||
Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); | |||
Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message); | |||
EXPECT_EQ("Operation 'feed' has no attr named 'missing'.", s.Message); | |||
// Make a constant oper with the scalar "3". | |||
var three = c_test_util.ScalarConst(3, graph, s); | |||
Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
// Add oper. | |||
var add = c_test_util.Add(feed, three, graph, s); | |||
Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
// Test TF_Operation*() query functions. | |||
Assert.AreEqual("add", add.Name); | |||
Assert.AreEqual("AddN", add.OpType); | |||
Assert.AreEqual("", add.Device); | |||
Assert.AreEqual(1, add.NumOutputs); | |||
Assert.AreEqual(TF_DataType.TF_INT32, add.OutputType(0)); | |||
Assert.AreEqual(1, add.OutputListLength("sum")); | |||
Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||
Assert.AreEqual(2, add.InputListLength("inputs")); | |||
Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||
Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(0)); | |||
Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(1)); | |||
EXPECT_EQ("add", add.Name); | |||
EXPECT_EQ("AddN", add.OpType); | |||
EXPECT_EQ("", add.Device); | |||
EXPECT_EQ(1, add.NumOutputs); | |||
EXPECT_EQ(TF_DataType.TF_INT32, add.OutputType(0)); | |||
EXPECT_EQ(1, add.OutputListLength("sum")); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
EXPECT_EQ(2, add.InputListLength("inputs")); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
EXPECT_EQ(TF_DataType.TF_INT32, add.InputType(0)); | |||
EXPECT_EQ(TF_DataType.TF_INT32, add.InputType(1)); | |||
var add_in_0 = add.Input(0); | |||
Assert.AreEqual(feed, add_in_0.oper); | |||
Assert.AreEqual(0, add_in_0.index); | |||
EXPECT_EQ(feed, add_in_0.oper); | |||
EXPECT_EQ(0, add_in_0.index); | |||
var add_in_1 = add.Input(1); | |||
Assert.AreEqual(three, add_in_1.oper); | |||
Assert.AreEqual(0, add_in_1.index); | |||
Assert.AreEqual(0, add.OutputNumConsumers(0)); | |||
Assert.AreEqual(0, add.NumControlInputs); | |||
Assert.AreEqual(0, add.NumControlOutputs); | |||
EXPECT_EQ(three, add_in_1.oper); | |||
EXPECT_EQ(0, add_in_1.index); | |||
EXPECT_EQ(0, add.OutputNumConsumers(0)); | |||
EXPECT_EQ(0, add.NumControlInputs); | |||
EXPECT_EQ(0, add.NumControlOutputs); | |||
Assert.IsTrue(c_test_util.GetAttrValue(add, "T", ref attr_value, s)); | |||
Assert.AreEqual(DataType.DtInt32, attr_value.Type); | |||
Assert.IsTrue(c_test_util.GetAttrValue(add, "N", ref attr_value, s)); | |||
Assert.AreEqual(2, attr_value.I); | |||
ASSERT_TRUE(c_test_util.GetAttrValue(add, "T", ref attr_value, s)); | |||
EXPECT_EQ(DataType.DtInt32, attr_value.Type); | |||
ASSERT_TRUE(c_test_util.GetAttrValue(add, "N", ref attr_value, s)); | |||
EXPECT_EQ(2, attr_value.I); | |||
// Placeholder oper now has a consumer. | |||
Assert.AreEqual(1, feed.OutputNumConsumers(0)); | |||
EXPECT_EQ(1, feed.OutputNumConsumers(0)); | |||
TF_Input[] feed_port = feed.OutputConsumers(0, 1); | |||
Assert.AreEqual(1, feed_port.Length); | |||
Assert.AreEqual(add, feed_port[0].oper); | |||
Assert.AreEqual(0, feed_port[0].index); | |||
EXPECT_EQ(1, feed_port.Length); | |||
EXPECT_EQ(add, feed_port[0].oper); | |||
EXPECT_EQ(0, feed_port[0].index); | |||
// The scalar const oper also has a consumer. | |||
Assert.AreEqual(1, three.OutputNumConsumers(0)); | |||
EXPECT_EQ(1, three.OutputNumConsumers(0)); | |||
TF_Input[] three_port = three.OutputConsumers(0, 1); | |||
Assert.AreEqual(add, three_port[0].oper); | |||
Assert.AreEqual(1, three_port[0].index); | |||
EXPECT_EQ(add, three_port[0].oper); | |||
EXPECT_EQ(1, three_port[0].index); | |||
// Serialize to GraphDef. | |||
var graph_def = c_test_util.GetGraphDef(graph); | |||
@@ -119,38 +121,38 @@ namespace TensorFlowNET.UnitTest | |||
Assert.Fail($"Unexpected NodeDef: {n}"); | |||
} | |||
} | |||
Assert.IsTrue(found_placeholder); | |||
Assert.IsTrue(found_scalar_const); | |||
Assert.IsTrue(found_add); | |||
ASSERT_TRUE(found_placeholder); | |||
ASSERT_TRUE(found_scalar_const); | |||
ASSERT_TRUE(found_add); | |||
// Add another oper to the graph. | |||
var neg = c_test_util.Neg(add, graph, s); | |||
Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
// Serialize to NodeDef. | |||
var node_def = c_test_util.GetNodeDef(neg); | |||
// Validate NodeDef is what we expect. | |||
Assert.IsTrue(c_test_util.IsNeg(node_def, "add")); | |||
ASSERT_TRUE(c_test_util.IsNeg(node_def, "add")); | |||
// Serialize to GraphDef. | |||
var graph_def2 = c_test_util.GetGraphDef(graph); | |||
// Compare with first GraphDef + added NodeDef. | |||
graph_def.Node.Add(node_def); | |||
Assert.AreEqual(graph_def.ToString(), graph_def2.ToString()); | |||
EXPECT_EQ(graph_def.ToString(), graph_def2.ToString()); | |||
// Look up some nodes by name. | |||
Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg"); | |||
Assert.AreEqual(neg, neg2); | |||
EXPECT_EQ(neg, neg2); | |||
var node_def2 = c_test_util.GetNodeDef(neg2); | |||
Assert.AreEqual(node_def.ToString(), node_def2.ToString()); | |||
EXPECT_EQ(node_def.ToString(), node_def2.ToString()); | |||
Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed"); | |||
Assert.AreEqual(feed, feed2); | |||
EXPECT_EQ(feed, feed2); | |||
node_def = c_test_util.GetNodeDef(feed); | |||
node_def2 = c_test_util.GetNodeDef(feed2); | |||
Assert.AreEqual(node_def.ToString(), node_def2.ToString()); | |||
EXPECT_EQ(node_def.ToString(), node_def2.ToString()); | |||
// Test iterating through the nodes of a graph. | |||
found_placeholder = false; | |||
@@ -189,13 +191,106 @@ namespace TensorFlowNET.UnitTest | |||
} | |||
} | |||
Assert.IsTrue(found_placeholder); | |||
Assert.IsTrue(found_scalar_const); | |||
Assert.IsTrue(found_add); | |||
Assert.IsTrue(found_neg); | |||
ASSERT_TRUE(found_placeholder); | |||
ASSERT_TRUE(found_scalar_const); | |||
ASSERT_TRUE(found_add); | |||
ASSERT_TRUE(found_neg); | |||
graph.Dispose(); | |||
s.Dispose(); | |||
} | |||
/// <summary> | |||
/// Port from c_api_test.cc | |||
/// `TEST(CAPI, ImportGraphDef)` | |||
/// </summary> | |||
[TestMethod] | |||
public void c_api_ImportGraphDef() | |||
{ | |||
var s = new Status(); | |||
var graph = new Graph(); | |||
// Create a simple graph. | |||
c_test_util.Placeholder(graph, s); | |||
var oper = c_test_util.ScalarConst(3, graph, s); | |||
c_test_util.Neg(oper, graph, s); | |||
// Export to a GraphDef. | |||
var graph_def = new Buffer(); | |||
c_api.TF_GraphToGraphDef(graph, graph_def, s); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
// Import it, with a prefix, in a fresh graph. | |||
graph.Dispose(); | |||
graph = new Graph(); | |||
var opts = c_api.TF_NewImportGraphDefOptions(); | |||
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); | |||
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
Operation scalar = c_api.TF_GraphOperationByName(graph, "imported/scalar"); | |||
Operation feed = c_api.TF_GraphOperationByName(graph, "imported/feed"); | |||
Operation neg = c_api.TF_GraphOperationByName(graph, "imported/neg"); | |||
// Test basic structure of the imported graph. | |||
EXPECT_EQ(0, scalar.NumInputs); | |||
EXPECT_EQ(0, feed.NumInputs); | |||
EXPECT_EQ(1, neg.NumInputs); | |||
var neg_input = neg.Input(0); | |||
EXPECT_EQ(scalar, neg_input.oper); | |||
EXPECT_EQ(0, neg_input.index); | |||
// Test that we can't see control edges involving the source and sink nodes. | |||
EXPECT_EQ(0, scalar.NumControlInputs); | |||
EXPECT_EQ(0, scalar.ControlInputs(100).Length); | |||
EXPECT_EQ(0, scalar.NumControlOutputs); | |||
EXPECT_EQ(0, scalar.ControlOutputs(100).Length); | |||
EXPECT_EQ(0, feed.NumControlInputs); | |||
EXPECT_EQ(0, feed.ControlInputs(100).Length); | |||
EXPECT_EQ(0, feed.NumControlOutputs); | |||
EXPECT_EQ(0, feed.ControlOutputs(100).Length); | |||
EXPECT_EQ(0, neg.NumControlInputs); | |||
EXPECT_EQ(0, neg.ControlInputs(100).Length); | |||
EXPECT_EQ(0, neg.NumControlOutputs); | |||
EXPECT_EQ(0, neg.ControlOutputs(100).Length); | |||
// Import it again, with an input mapping, return outputs, and a return | |||
// operation, into the same graph. | |||
c_api.TF_DeleteImportGraphDefOptions(opts); | |||
opts = c_api.TF_NewImportGraphDefOptions(); | |||
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2"); | |||
c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0)); | |||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); | |||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0); | |||
EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); | |||
c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); | |||
EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); | |||
var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
Operation scalar2 = c_api.TF_GraphOperationByName(graph, "imported2/scalar"); | |||
Operation feed2 = c_api.TF_GraphOperationByName(graph, "imported2/feed"); | |||
Operation neg2 = c_api.TF_GraphOperationByName(graph, "imported2/neg"); | |||
// Check input mapping | |||
neg_input = neg.Input(0); | |||
EXPECT_EQ(scalar, neg_input.oper); | |||
EXPECT_EQ(0, neg_input.index); | |||
// Check return outputs | |||
var return_outputs = graph.ReturnOutputs(results); | |||
ASSERT_EQ(2, return_outputs.Length); | |||
EXPECT_EQ(feed2, return_outputs[0].oper); | |||
EXPECT_EQ(0, return_outputs[0].index); | |||
EXPECT_EQ(scalar, return_outputs[1].oper); // remapped | |||
EXPECT_EQ(0, return_outputs[1].index); | |||
// Check return operation | |||
var num_return_opers = graph.ReturnOperations(results); | |||
ASSERT_EQ(1, num_return_opers); | |||
} | |||
} | |||
} |
@@ -10,7 +10,7 @@ using Tensorflow; | |||
namespace TensorFlowNET.UnitTest | |||
{ | |||
[TestClass] | |||
public class TensorTest | |||
public class TensorTest : CApiTest | |||
{ | |||
/// <summary> | |||
/// Port from c_api_test.cc | |||
@@ -22,10 +22,10 @@ namespace TensorFlowNET.UnitTest | |||
ulong num_bytes = 6 * sizeof(float); | |||
long[] dims = { 2, 3 }; | |||
Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); | |||
Assert.AreEqual(TF_DataType.TF_FLOAT, t.dtype); | |||
Assert.AreEqual(2, t.NDims); | |||
EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype); | |||
EXPECT_EQ(2, t.NDims); | |||
Assert.IsTrue(Enumerable.SequenceEqual(dims, t.shape)); | |||
Assert.AreEqual(num_bytes, t.bytesize); | |||
EXPECT_EQ(num_bytes, t.bytesize); | |||
t.Dispose(); | |||
} | |||
@@ -41,11 +41,11 @@ namespace TensorFlowNET.UnitTest | |||
var tensor = new Tensor(nd); | |||
var array = tensor.Data<float>(); | |||
Assert.AreEqual(tensor.dtype, TF_DataType.TF_FLOAT); | |||
Assert.AreEqual(tensor.rank, nd.ndim); | |||
Assert.AreEqual(tensor.shape[0], nd.shape[0]); | |||
Assert.AreEqual(tensor.shape[1], nd.shape[1]); | |||
Assert.AreEqual(tensor.bytesize, (uint)nd.size * sizeof(float)); | |||
EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); | |||
EXPECT_EQ(tensor.rank, nd.ndim); | |||
EXPECT_EQ(tensor.shape[0], nd.shape[0]); | |||
EXPECT_EQ(tensor.shape[1], nd.shape[1]); | |||
EXPECT_EQ(tensor.bytesize, (uint)nd.size * sizeof(float)); | |||
Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array)); | |||
} | |||
@@ -66,20 +66,20 @@ namespace TensorFlowNET.UnitTest | |||
int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
Assert.AreEqual(-1, num_dims); | |||
EXPECT_EQ(-1, num_dims); | |||
// Set the shape to be unknown, expect no change. | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||
Assert.AreEqual(-1, num_dims); | |||
EXPECT_EQ(-1, num_dims); | |||
// Set the shape to be 2 x Unknown | |||
long[] dims = { 2, -1 }; | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||
Assert.AreEqual(2, num_dims); | |||
EXPECT_EQ(2, num_dims); | |||
// Get the dimension vector appropriately. | |||
var returned_dims = new long[dims.Length]; | |||
@@ -103,9 +103,9 @@ namespace TensorFlowNET.UnitTest | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
Assert.AreEqual(2, num_dims); | |||
Assert.AreEqual(2, returned_dims[0]); | |||
Assert.AreEqual(3, returned_dims[1]); | |||
EXPECT_EQ(2, num_dims); | |||
EXPECT_EQ(2, returned_dims[0]); | |||
EXPECT_EQ(3, returned_dims[1]); | |||
// Try to set 'unknown' with same rank on the shape and see that | |||
// it doesn't change. | |||
@@ -115,9 +115,9 @@ namespace TensorFlowNET.UnitTest | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
Assert.AreEqual(2, num_dims); | |||
Assert.AreEqual(2, returned_dims[0]); | |||
Assert.AreEqual(3, returned_dims[1]); | |||
EXPECT_EQ(2, num_dims); | |||
EXPECT_EQ(2, returned_dims[0]); | |||
EXPECT_EQ(3, returned_dims[1]); | |||
// Try to fetch a shape with the wrong num_dims | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); | |||
@@ -135,7 +135,7 @@ namespace TensorFlowNET.UnitTest | |||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
Assert.AreEqual(0, num_dims); | |||
EXPECT_EQ(0, num_dims); | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s); | |||
//Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||