using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using Tensorflow;
using Buffer = Tensorflow.Buffer;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest
{
    /// <summary>
    /// Graph construction, inspection and GraphDef import tests.
    /// Most cases are ports of tests from c_api_test.cc.
    /// </summary>
    [TestClass]
    public class GraphTest : CApiTest
    {
        /// <summary>
        /// Port from c_api_test.cc
        /// `TEST(CAPI, Graph)`
        /// </summary>
        [TestMethod]
        public void Graph()
        {
            var s = new Status();
            var graph = new Graph();

            // Make a placeholder operation.
            var feed = c_test_util.Placeholder(graph, s);
            EXPECT_EQ("feed", feed.name);
            EXPECT_EQ("Placeholder", feed.OpType);
            EXPECT_EQ("", feed.Device);
            EXPECT_EQ(1, feed.NumOutputs);
            EXPECT_EQ(TF_DataType.TF_INT32, feed.OutputType(0));
            EXPECT_EQ(1, feed.OutputListLength("output"));
            EXPECT_EQ(0, feed.NumInputs);
            EXPECT_EQ(0, feed.OutputNumConsumers(0));
            EXPECT_EQ(0, feed.NumControlInputs);
            EXPECT_EQ(0, feed.NumControlOutputs);

            AttrValue attr_value = null;
            ASSERT_TRUE(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s));
            // Expected value first, matching every other assertion in this file.
            EXPECT_EQ(DataType.DtInt32, attr_value.Type);

            // Test not found errors in TF_Operation*() query functions.
            EXPECT_EQ(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s));
            EXPECT_EQ(TF_Code.TF_INVALID_ARGUMENT, s.Code);
            Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s));
            EXPECT_EQ("Operation 'feed' has no attr named 'missing'.", s.Message);

            // Make a constant oper with the scalar "3".
            var three = c_test_util.ScalarConst(3, graph, s);
            EXPECT_EQ(TF_Code.TF_OK, s.Code);

            // Add oper.
            var add = c_test_util.Add(feed, three, graph, s);
            EXPECT_EQ(TF_Code.TF_OK, s.Code);

            // Test TF_Operation*() query functions.
            EXPECT_EQ("add", add.name);
            EXPECT_EQ("AddN", add.OpType);
            EXPECT_EQ("", add.Device);
            EXPECT_EQ(1, add.NumOutputs);
            EXPECT_EQ(TF_DataType.TF_INT32, add.OutputType(0));
            EXPECT_EQ(1, add.OutputListLength("sum"));
            EXPECT_EQ(TF_Code.TF_OK, s.Code);
            EXPECT_EQ(2, add.InputListLength("inputs"));
            EXPECT_EQ(TF_Code.TF_OK, s.Code);
            EXPECT_EQ(TF_DataType.TF_INT32, add.InputType(0));
            EXPECT_EQ(TF_DataType.TF_INT32, add.InputType(1));

            var add_in_0 = add.Input(0);
            EXPECT_EQ(feed, add_in_0.oper);
            EXPECT_EQ(0, add_in_0.index);

            var add_in_1 = add.Input(1);
            EXPECT_EQ(three, add_in_1.oper);
            EXPECT_EQ(0, add_in_1.index);

            EXPECT_EQ(0, add.OutputNumConsumers(0));
            EXPECT_EQ(0, add.NumControlInputs);
            EXPECT_EQ(0, add.NumControlOutputs);

            ASSERT_TRUE(c_test_util.GetAttrValue(add, "T", ref attr_value, s));
            EXPECT_EQ(DataType.DtInt32, attr_value.Type);
            ASSERT_TRUE(c_test_util.GetAttrValue(add, "N", ref attr_value, s));
            EXPECT_EQ(2, (int)attr_value.I);

            // Placeholder oper now has a consumer.
            EXPECT_EQ(1, feed.OutputNumConsumers(0));
            TF_Input[] feed_port = feed.OutputConsumers(0, 1);
            EXPECT_EQ(1, feed_port.Length);
            EXPECT_EQ(add, feed_port[0].oper);
            EXPECT_EQ(0, feed_port[0].index);

            // The scalar const oper also has a consumer.
            EXPECT_EQ(1, three.OutputNumConsumers(0));
            TF_Input[] three_port = three.OutputConsumers(0, 1);
            EXPECT_EQ(add, three_port[0].oper);
            EXPECT_EQ(1, three_port[0].index);

            // Serialize to GraphDef.
            var graph_def = c_test_util.GetGraphDef(graph);

            // Validate GraphDef is what we expect.
            bool found_placeholder = false;
            bool found_scalar_const = false;
            bool found_add = false;
            foreach (var n in graph_def.Node)
            {
                if (c_test_util.IsPlaceholder(n))
                {
                    Assert.IsFalse(found_placeholder);
                    found_placeholder = true;
                }
                else if (c_test_util.IsScalarConst(n, 3))
                {
                    Assert.IsFalse(found_scalar_const);
                    found_scalar_const = true;
                }
                else if (c_test_util.IsAddN(n, 2))
                {
                    Assert.IsFalse(found_add);
                    found_add = true;
                }
                else
                {
                    Assert.Fail($"Unexpected NodeDef: {n}");
                }
            }
            ASSERT_TRUE(found_placeholder);
            ASSERT_TRUE(found_scalar_const);
            ASSERT_TRUE(found_add);

            // Add another oper to the graph.
            var neg = c_test_util.Neg(add, graph, s);
            EXPECT_EQ(TF_Code.TF_OK, s.Code);

            // Serialize to NodeDef.
            var node_def = neg.node_def;

            // Validate NodeDef is what we expect.
            ASSERT_TRUE(c_test_util.IsNeg(node_def, "add"));

            // Serialize to GraphDef.
            var graph_def2 = c_test_util.GetGraphDef(graph);

            // Compare with first GraphDef + added NodeDef.
            graph_def.Node.Add(node_def);
            EXPECT_EQ(graph_def.ToString(), graph_def2.ToString());

            // Look up some nodes by name.
            Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg");
            EXPECT_EQ(neg, neg2);
            var node_def2 = neg2.node_def;
            EXPECT_EQ(node_def.ToString(), node_def2.ToString());

            Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed");
            EXPECT_EQ(feed, feed2);
            node_def = feed.node_def;
            node_def2 = feed2.node_def;
            EXPECT_EQ(node_def.ToString(), node_def2.ToString());

            // Test iterating through the nodes of a graph.
            found_placeholder = false;
            found_scalar_const = false;
            found_add = false;
            bool found_neg = false;
            uint pos = 0;
            Operation oper;

            while ((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero)
            {
                if (oper.Equals(feed))
                {
                    Assert.IsFalse(found_placeholder);
                    found_placeholder = true;
                }
                else if (oper.Equals(three))
                {
                    Assert.IsFalse(found_scalar_const);
                    found_scalar_const = true;
                }
                else if (oper.Equals(add))
                {
                    Assert.IsFalse(found_add);
                    found_add = true;
                }
                else if (oper.Equals(neg))
                {
                    Assert.IsFalse(found_neg);
                    found_neg = true;
                }
                else
                {
                    node_def = oper.node_def;
                    Assert.Fail($"Unexpected Node: {node_def.ToString()}");
                }
            }

            ASSERT_TRUE(found_placeholder);
            ASSERT_TRUE(found_scalar_const);
            ASSERT_TRUE(found_add);
            ASSERT_TRUE(found_neg);

            graph.Dispose();
            s.Dispose();
        }

        /// <summary>
        /// Port from c_api_test.cc
        /// `TEST(CAPI, ImportGraphDef)`
        /// </summary>
        [TestMethod]
        public void ImportGraphDef()
        {
            var s = new Status();
            var graph = new Graph().as_default();

            // Create a simple graph.
            c_test_util.Placeholder(graph, s);
            var oper = c_test_util.ScalarConst(3, graph, s);
            c_test_util.Neg(oper, graph, s);

            // Export to a GraphDef.
            var graph_def = new Buffer();
            c_api.TF_GraphToGraphDef(graph, graph_def, s);
            EXPECT_EQ(TF_Code.TF_OK, s.Code);

            // Import it, with a prefix, in a fresh graph.
            graph.Dispose();
            graph = new Graph().as_default();
            var opts = c_api.TF_NewImportGraphDefOptions();
            c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
            c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s);
            EXPECT_EQ(TF_Code.TF_OK, s.Code);

            Operation scalar = graph.OperationByName("imported/scalar");
            Operation feed = graph.OperationByName("imported/feed");
            Operation neg = graph.OperationByName("imported/neg");

            // Test basic structure of the imported graph.
            EXPECT_EQ(0, scalar.NumInputs);
            EXPECT_EQ(0, feed.NumInputs);
            EXPECT_EQ(1, neg.NumInputs);

            var neg_input = neg.Input(0);
            EXPECT_EQ(scalar, neg_input.oper);
            EXPECT_EQ(0, neg_input.index);

            // Test that we can't see control edges involving the source and sink nodes.
            EXPECT_EQ(0, scalar.NumControlInputs);
            EXPECT_EQ(0, scalar.GetControlInputs().Length);
            EXPECT_EQ(0, scalar.NumControlOutputs);
            EXPECT_EQ(0, scalar.GetControlOutputs().Length);

            EXPECT_EQ(0, feed.NumControlInputs);
            EXPECT_EQ(0, feed.GetControlInputs().Length);
            EXPECT_EQ(0, feed.NumControlOutputs);
            EXPECT_EQ(0, feed.GetControlOutputs().Length);

            EXPECT_EQ(0, neg.NumControlInputs);
            EXPECT_EQ(0, neg.GetControlInputs().Length);
            EXPECT_EQ(0, neg.NumControlOutputs);
            EXPECT_EQ(0, neg.GetControlOutputs().Length);

            // Import it again, with an input mapping, return outputs, and a return
            // operation, into the same graph.
            c_api.TF_DeleteImportGraphDefOptions(opts);
            opts = c_api.TF_NewImportGraphDefOptions();
            c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2");
            c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0));
            c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
            c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
            EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts));
            c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar");
            EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts));
            var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
            EXPECT_EQ(TF_Code.TF_OK, s.Code);

            Operation scalar2 = graph.OperationByName("imported2/scalar");
            Operation feed2 = graph.OperationByName("imported2/feed");
            Operation neg2 = graph.OperationByName("imported2/neg");

            // Check input mapping
            // NOTE(review): this re-reads the input of `neg` (first import), not `neg2`;
            // kept as-is - confirm whether `neg2` was intended.
            neg_input = neg.Input(0);
            EXPECT_EQ(scalar, neg_input.oper);
            EXPECT_EQ(0, neg_input.index);

            // Check return outputs
            var return_outputs = graph.ReturnOutputs(results);
            ASSERT_EQ(2, return_outputs.Length);
            EXPECT_EQ(feed2, return_outputs[0].oper);
            EXPECT_EQ(0, return_outputs[0].index);
            EXPECT_EQ(scalar, return_outputs[1].oper); // remapped
            EXPECT_EQ(0, return_outputs[1].index);

            // Check return operation
            var return_opers = graph.ReturnOperations(results);
            ASSERT_EQ(1, return_opers.Length);
            EXPECT_EQ(scalar2, return_opers[0]); // not remapped
            c_api.TF_DeleteImportGraphDefResults(results);

            // Import again, with control dependencies, into the same graph.
            c_api.TF_DeleteImportGraphDefOptions(opts);
            opts = c_api.TF_NewImportGraphDefOptions();
            c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3");
            c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed);
            c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2);
            c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s);
            EXPECT_EQ(TF_Code.TF_OK, s.Code);

            var scalar3 = graph.OperationByName("imported3/scalar");
            var feed3 = graph.OperationByName("imported3/feed");
            var neg3 = graph.OperationByName("imported3/neg");
            ASSERT_TRUE(scalar3 != IntPtr.Zero);
            ASSERT_TRUE(feed3 != IntPtr.Zero);
            ASSERT_TRUE(neg3 != IntPtr.Zero);

            // Check that newly-imported scalar and feed have control deps (neg3 will
            // inherit them from input)
            var control_inputs = scalar3.GetControlInputs();
            ASSERT_EQ(2, scalar3.NumControlInputs);
            EXPECT_EQ(feed, control_inputs[0]);
            EXPECT_EQ(feed2, control_inputs[1]);

            control_inputs = feed3.GetControlInputs();
            ASSERT_EQ(2, feed3.NumControlInputs);
            EXPECT_EQ(feed, control_inputs[0]);
            EXPECT_EQ(feed2, control_inputs[1]);

            // Export to a graph def so we can import a graph with control dependencies
            // Release the first exported buffer before replacing it.
            graph_def.Dispose();
            graph_def = new Buffer();
            c_api.TF_GraphToGraphDef(graph, graph_def, s);
            EXPECT_EQ(TF_Code.TF_OK, s.Code);

            // Import again, with remapped control dependency, into the same graph
            c_api.TF_DeleteImportGraphDefOptions(opts);
            opts = c_api.TF_NewImportGraphDefOptions();
            c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4");
            c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed);
            c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s);
            ASSERT_EQ(TF_Code.TF_OK, s.Code);

            var scalar4 = graph.OperationByName("imported4/imported3/scalar");
            var feed4 = graph.OperationByName("imported4/imported2/feed");

            // Check that imported `imported3/scalar` has remapped control dep from
            // original graph and imported control dep
            control_inputs = scalar4.GetControlInputs();
            ASSERT_EQ(2, scalar4.NumControlInputs);
            EXPECT_EQ(feed, control_inputs[0]);
            EXPECT_EQ(feed4, control_inputs[1]);

            c_api.TF_DeleteImportGraphDefOptions(opts);

            // Can add nodes to the imported graph without trouble.
            c_test_util.Add(feed, scalar, graph, s);
            ASSERT_EQ(TF_Code.TF_OK, s.Code);

            // Release remaining resources, as the sibling tests in this class do.
            graph_def.Dispose();
            graph.Dispose();
            s.Dispose();
        }

        /// <summary>
        /// Port from c_api_test.cc
        /// `TEST(CAPI, ImportGraphDef_WithReturnOutputs)`
        /// </summary>
        [TestMethod]
        public void ImportGraphDef_WithReturnOutputs()
        {
            var s = new Status();
            var graph = new Graph().as_default();

            // Create a graph with two nodes: x and 3
            c_test_util.Placeholder(graph, s);
            ASSERT_TRUE(graph.OperationByName("feed") != null);
            var oper = c_test_util.ScalarConst(3, graph, s);
            ASSERT_TRUE(graph.OperationByName("scalar") != null);
            c_test_util.Neg(oper, graph, s);
            ASSERT_TRUE(graph.OperationByName("neg") != null);

            // Export to a GraphDef.
            var graph_def = graph.ToGraphDef(s);
            ASSERT_EQ(TF_Code.TF_OK, s.Code);

            // Import it in a fresh graph with return outputs.
            graph.Dispose();
            graph = new Graph().as_default();
            var opts = new ImportGraphDefOptions();
            opts.AddReturnOutput("feed", 0);
            opts.AddReturnOutput("scalar", 0);
            EXPECT_EQ(2, opts.NumReturnOutputs);
            var return_outputs = graph.ImportGraphDefWithReturnOutputs(graph_def, opts, s);
            ASSERT_EQ(TF_Code.TF_OK, s.Code);

            var scalar = graph.OperationByName("scalar");
            var feed = graph.OperationByName("feed");
            var neg = graph.OperationByName("neg");
            ASSERT_TRUE(scalar != IntPtr.Zero);
            ASSERT_TRUE(feed != IntPtr.Zero);
            ASSERT_TRUE(neg != IntPtr.Zero);

            // Check return outputs
            EXPECT_EQ(feed, return_outputs[0].oper);
            EXPECT_EQ(0, return_outputs[0].index);
            EXPECT_EQ(scalar, return_outputs[1].oper);
            EXPECT_EQ(0, return_outputs[1].index);

            opts.Dispose();
            graph_def.Dispose();
            graph.Dispose();
            s.Dispose();
        }

        /// <summary>
        /// `TEST(CAPI, ImportGraphDef_MissingUnusedInputMappings)`
        /// </summary>
        [TestMethod]
        public void ImportGraphDef_MissingUnusedInputMappings()
        {
            // TODO: body is empty, so this test currently passes without asserting anything.
        }

        /// <summary>
        /// Imports a saved meta graph, restores its checkpoint and builds a
        /// sparse softmax cross-entropy loss on the "logits" collection.
        /// Marked [Ignore]; reads model files from the relative path "my-save-dir/".
        /// </summary>
        [Ignore]
        [TestMethod]
        public void ImportGraphMeta()
        {
            var dir = "my-save-dir/";
            using (var sess = tf.Session())
            {
                var new_saver = tf.train.import_meta_graph(dir + "my-model-10000.meta");
                new_saver.restore(sess, dir + "my-model-10000");
                var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels");
                var batch_size = tf.size(labels);
                // NOTE(review): the generic argument was missing (bare `List` does not compile);
                // ITensorOrOperation is assumed as the element type - confirm against tf.get_collection.
                var logits = (tf.get_collection("logits") as List<ITensorOrOperation>)[0] as Tensor;
                var loss = tf.losses.sparse_softmax_cross_entropy(labels: labels, logits: logits);
            }
        }
    }
}