You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CreateOpFromTfOperationTest.cs 9.8 kB


using System;
using System.Linq;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using Tensorflow.Operations;
using static Tensorflow.Binding;
  7. namespace TensorFlowNET.UnitTest.ops_test
  8. {
  9. /// <summary>
  10. /// excerpt of tensorflow/python/framework/ops_test.py
  11. /// # These cases test the private Graph._create_op_from_tf_operation
  12. /// # method. Arguably we should only test the public APIs that depend on this
  13. /// # method. However, this logic is complex and tricky, and it can be difficult to
  14. /// # ascertain if we have adequate coverage (e.g. a graph may run successfully if
  15. /// # the control flow context isn't set properly, but a more complicated use case
  16. /// # that might not be obvious to test will fail). Thus we instead explicitly test
  17. /// # the low-level behavior.
  18. /// </summary>
  19. [TestClass]
  20. public class CreateOpFromTfOperationTest : PythonTest
  21. {
  22. [TestMethod]
  23. public void TestShape()
  24. {
  25. using (var g = tf.Graph().as_default())
  26. {
  27. var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}});
  28. var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]);
  29. var op = g._create_op_from_tf_operation(c_op);
  30. Assert.AreEqual("myop", op.name);
  31. Assert.AreEqual("Identity", op.type);
  32. Assert.AreEqual(1, len(op.outputs));
  33. assertItemsEqual(new[] {2, 3}, op.outputs[0].shape);
  34. }
  35. }
  36. [TestMethod]
  37. public void TestUniqueName()
  38. {
  39. var graph = tf.Graph().as_default();
  40. //var (c_op,op_desc) = ops._create_c_op(g, ops._NodeDef("Const", "myop"), new Tensor[0], new Operation[0]);
  41. //var (c_op2, op_desc1) = ops._create_c_op(g, ops._NodeDef("Const", "myop_1"), new Tensor[0], new Operation[0]);
  42. //var op = g._create_op_from_tf_operation(c_op);
  43. //var op2 = g._create_op_from_tf_operation(c_op2);
  44. var op = constant_op.constant(0, name: "myop").op;
  45. var op2 = constant_op.constant(0, name: "myop_1").op;
  46. // Create ops with same names as op1 and op2. We expect the new names to be
  47. // uniquified.
  48. var op3 = constant_op.constant(0, name: "myop").op;
  49. var op4 = constant_op.constant(0, name: "myop_1").op;
  50. self.assertEqual(op.name, "myop");
  51. self.assertEqual(op2.name, "myop_1");
  52. self.assertEqual(op3.name, "myop_2");
  53. self.assertEqual(op4.name, "myop_1_1");
  54. }
  55. [Ignore("need tesnroflow expose UpdateEdge API")]
  56. [TestMethod]
  57. public void TestCond()
  58. {
  59. var g = tf.Graph().as_default();
  60. var x = constant_op.constant(10);
  61. var true_fn = new Func<Tensor>(() =>
  62. {
  63. var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]);
  64. var new_ops = g._add_new_tf_operations();
  65. self.assertEqual(len(new_ops), 1);
  66. return x;
  67. });
  68. control_flow_ops.cond(x < 10, true_fn, () => x);
  69. var op = g.get_operation_by_name("cond/myop");
  70. //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta.txt", as_text:true);
  71. //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
  72. self.assertIsNotNone(op);
  73. self.assertEqual(op.name, "cond/myop");
  74. self.assertEqual(op.type, "Identity");
  75. //self.assertEqual(op.outputs, new object[0]);
  76. var op_input = op.inputs[0].op;
  77. self.assertEqual(op_input.type, "Switch");
  78. self.assertEqual(op_input.inputs[0].name, x.name);
  79. self.assertEqual(op.graph, g);
  80. self.assertIsNotNone(op._get_control_flow_context());
  81. var cond_text = op._get_control_flow_context() as ControlFlowContext;
  82. self.assertEqual(cond_text.name, "cond/cond_text");
  83. }
  84. [Ignore("Todo: Port")]
  85. [TestMethod]
  86. public void TestWhileLoop()
  87. {
  88. var graph = tf.Graph().as_default();
  89. Operation x=null;
  90. x = constant_op.constant(42);
  91. var body = new Func<int, int>(i =>
  92. {
  93. ops._create_c_op(ops.get_default_graph(), ops._NodeDef("Identity", "myloop/myop"), new[] {x},
  94. new Operation[0]);
  95. var new_ops = graph._add_new_tf_operations();
  96. self.assertEqual(len(new_ops), 1);
  97. return i;
  98. });
  99. // TODO: port control_flow_ops.while_loop
  100. //control_flow_ops.while_loop( i => i < 10, body, new int[]{0}, name = "myloop");
  101. var op = graph.get_operation_by_name("myloop/myop");
  102. self.assertIsNotNone(op);
  103. self.assertEqual(op.name, "myloop/myop");
  104. self.assertEqual(op.type, "Identity");
  105. self.assertEqual(op.outputs.Length, 0);
  106. var op_input = op.inputs[0].op;
  107. self.assertEqual(op_input.type, "Enter");
  108. self.assertItemsEqual(op_input.inputs.OfType<Operation>().ToArray(), new[] {x});
  109. self.assertEqual(op.graph, graph);
  110. self.assertIsNotNone(op._get_control_flow_context());
  111. self.assertEqual(((ControlFlowContext)op._get_control_flow_context()).name, "myloop/while_context");
  112. /*
  113. @test_util.run_v1_only("b/120545219")
  114. def testWhileLoop(self):
  115. g = ops.Graph()
  116. with g.as_default():
  117. x = test_ops.int_output()
  118. def body(i):
  119. ops._create_c_op(ops.get_default_graph(),
  120. ops._NodeDef("IntInput", "myloop/myop"), [x], [])
  121. new_ops = g._add_new_tf_operations()
  122. self.assertEqual(len(new_ops), 1)
  123. return i
  124. control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
  125. op = g.get_operation_by_name("myloop/myop")
  126. self.assertIsNotNone(op)
  127. self.assertEqual(op.name, "myloop/myop")
  128. self.assertEqual(op.type, "IntInput")
  129. self.assertEqual(op.outputs, [])
  130. op_input = op.inputs[0].op
  131. self.assertEqual(op_input.type, "Enter")
  132. self.assertEqual(list(op_input.inputs), [x])
  133. self.assertEqual(op.graph, g)
  134. # pylint: disable=protected-access
  135. self.assertIsNotNone(op._get_control_flow_context())
  136. self.assertEqual(op._get_control_flow_context().name,
  137. "myloop/while_context")
  138. # pylint: enable=protected-access
  139. */
  140. }
  141. [Ignore("Todo: Port")]
  142. [TestMethod]
  143. public void TestWhileLoopWithInternalControlDep()
  144. {
  145. /*
  146. @test_util.run_v1_only("b/120545219")
  147. def testWhileLoopWithInternalControlDep(self):
  148. g = ops.Graph()
  149. with g.as_default():
  150. x = test_ops.int_output()
  151. def body(i):
  152. c = constant_op.constant(1.0, name="c")
  153. ops._create_c_op(ops.get_default_graph(),
  154. ops._NodeDef("IntInput", "myloop/myop"), [x], [])
  155. with ops.control_dependencies([c]):
  156. new_ops = g._add_new_tf_operations()
  157. self.assertEqual(len(new_ops), 1)
  158. return i
  159. control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
  160. op = g.get_operation_by_name("myloop/myop")
  161. self.assertIsNotNone(op)
  162. c = g.get_operation_by_name("myloop/c")
  163. self.assertIsNotNone(c)
  164. # Internal control dep is preserved
  165. self.assertEqual(op.control_inputs, [c])
  166. */
  167. }
  168. [Ignore("Todo: Port")]
  169. [TestMethod]
  170. public void TestWhileLoopWithExternalControlDep()
  171. {
  172. /*
  173. @test_util.run_v1_only("b/120545219")
  174. def testWhileLoopWithExternalControlDep(self):
  175. g = ops.Graph()
  176. with g.as_default():
  177. x = test_ops.int_output()
  178. c = constant_op.constant(1.0)
  179. def body(i):
  180. ops._create_c_op(ops.get_default_graph(),
  181. ops._NodeDef("IntInput", "myloop/myop"), [x], [])
  182. with ops.control_dependencies([c]):
  183. new_ops = g._add_new_tf_operations()
  184. self.assertEqual(len(new_ops), 1)
  185. return i
  186. control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
  187. op = g.get_operation_by_name("myloop/myop")
  188. self.assertIsNotNone(op)
  189. # External control dep is removed and replaced with internal control dep
  190. self.assertNotEqual(op.control_inputs[0], c.op)
  191. self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
  192. */
  193. }
  194. }
  195. }