using Microsoft.VisualStudio.TestTools.UnitTesting; using System; using System.Collections.Generic; using System.Text; using Tensorflow; namespace TensorFlowNET.UnitTest { [TestClass] public class GraphTest { /// /// Port from c_api_test.cc /// `TEST(CAPI, Graph)` /// [TestMethod] public void c_api_Graph() { var s = new Status(); var graph = new Graph(); // Make a placeholder operation. var feed = c_test_util.Placeholder(graph, s); Assert.AreEqual("feed", feed.Name); Assert.AreEqual("Placeholder", feed.OpType); Assert.AreEqual("", feed.Device); Assert.AreEqual(1, feed.NumOutputs); Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType(0)); Assert.AreEqual(1, feed.OutputListLength("output")); Assert.AreEqual(0, feed.NumInputs); Assert.AreEqual(0, feed.OutputNumConsumers(0)); Assert.AreEqual(0, feed.NumControlInputs); Assert.AreEqual(0, feed.NumControlOutputs); AttrValue attr_value = null; Assert.IsTrue(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s)); Assert.AreEqual(attr_value.Type, DataType.DtInt32); // Test not found errors in TF_Operation*() query functions. Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code); Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message); // Make a constant oper with the scalar "3". var three = c_test_util.ScalarConst(3, graph, s); Assert.AreEqual(TF_Code.TF_OK, s.Code); // Add oper. var add = c_test_util.Add(feed, three, graph, s); Assert.AreEqual(TF_Code.TF_OK, s.Code); // Test TF_Operation*() query functions. Assert.AreEqual("add", add.Name); Assert.AreEqual("AddN", add.OpType); Assert.AreEqual("", add.Device); Assert.AreEqual(1, add.NumOutputs); Assert.AreEqual(TF_DataType.TF_INT32, add.OutputType(0)); Assert.AreEqual(1, add.OutputListLength("sum")); Assert.AreEqual(TF_Code.TF_OK, s.Code); Assert.AreEqual(2, add.InputListLength("inputs")); Assert.AreEqual(TF_Code.TF_OK, s.Code); Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(0)); Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(1)); var add_in_0 = add.Input(0); Assert.AreEqual(feed, add_in_0.oper); Assert.AreEqual(0, add_in_0.index); var add_in_1 = add.Input(1); Assert.AreEqual(three, add_in_1.oper); Assert.AreEqual(0, add_in_1.index); Assert.AreEqual(0, add.OutputNumConsumers(0)); Assert.AreEqual(0, add.NumControlInputs); Assert.AreEqual(0, add.NumControlOutputs); Assert.IsTrue(c_test_util.GetAttrValue(add, "T", ref attr_value, s)); Assert.AreEqual(DataType.DtInt32, attr_value.Type); Assert.IsTrue(c_test_util.GetAttrValue(add, "N", ref attr_value, s)); Assert.AreEqual(2, attr_value.I); // Placeholder oper now has a consumer. Assert.AreEqual(1, feed.OutputNumConsumers(0)); TF_Input[] feed_port = feed.OutputConsumers(0, 1); Assert.AreEqual(1, feed_port.Length); Assert.AreEqual(add, feed_port[0].oper); Assert.AreEqual(0, feed_port[0].index); // The scalar const oper also has a consumer. Assert.AreEqual(1, three.OutputNumConsumers(0)); TF_Input[] three_port = three.OutputConsumers(0, 1); Assert.AreEqual(add, three_port[0].oper); Assert.AreEqual(1, three_port[0].index); // Serialize to GraphDef. var graph_def = c_test_util.GetGraphDef(graph); // Validate GraphDef is what we expect. bool found_placeholder = false; bool found_scalar_const = false; bool found_add = false; foreach (var n in graph_def.Node) { if (c_test_util.IsPlaceholder(n)) { Assert.IsFalse(found_placeholder); found_placeholder = true; } /*else if (IsScalarConst(n, 3)) { Assert.IsFalse(found_scalar_const); found_scalar_const = true; } else if (IsAddN(n, 2)) { Assert.IsFalse(found_add); found_add = true; } else { ADD_FAILURE() << "Unexpected NodeDef: " << ProtoDebugString(n); }*/ } } } }