You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CreateOpFromTfOperationTest.cs 10 kB


  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using Microsoft.VisualStudio.TestTools.UnitTesting;
  6. using Tensorflow;
  7. using Tensorflow.Operations;
  8. using static Tensorflow.Python;
  9. namespace TensorFlowNET.UnitTest.ops_test
  10. {
  11. /// <summary>
  12. /// excerpt of tensorflow/python/framework/ops_test.py
  13. /// # These cases test the private Graph._create_op_from_tf_operation
  14. /// # method. Arguably we should only test the public APIs that depend on this
  15. /// # method. However, this logic is complex and tricky, and it can be difficult to
  16. /// # ascertain if we have adequate coverage (e.g. a graph may run successfully if
  17. /// # the control flow context isn't set properly, but a more complicated use case
  18. /// # that might not be obvious to test will fail). Thus we instead explicitly test
  19. /// # the low-level behavior.
  20. /// </summary>
  21. [TestClass]
  22. public class CreateOpFromTfOperationTest : PythonTest
  23. {
  /// <summary>
  /// Creates an "Identity" op named "myop" over a 2x3 integer constant via
  /// <c>ops._create_c_op</c> and <c>Graph._create_op_from_tf_operation</c>, then checks the
  /// wrapped op's name, type, output count and the shape of its single output.
  /// </summary>
  [TestMethod]
  public void TestShape()
  {
      var graph = tf.Graph().as_default();
      with<Graph>(graph, g =>
      {
          // 2 rows x 3 columns; the shape assertion at the end relies on this layout.
          var matrix = constant_op.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } });
          var nodeDef = ops._NodeDef("Identity", "myop");

          // The second tuple element (the op description) is not needed by this test.
          var (nativeOp, _) = ops._create_c_op(g, nodeDef, new[] { matrix }, new Operation[0]);
          var wrapped = g._create_op_from_tf_operation(nativeOp);

          Assert.AreEqual("myop", wrapped.name);
          Assert.AreEqual("Identity", wrapped.type);
          Assert.AreEqual(1, len(wrapped.outputs));

          var onlyOutput = wrapped.outputs[0];
          assertItemsEqual(new[] { 2, 3 }, onlyOutput.shape);
      });
  }
  /// <summary>
  /// Checks that creating ops whose requested names are already taken in the graph
  /// yields uniquified names ("myop" -> "myop_2", "myop_1" -> "myop_1_1").
  /// </summary>
  [TestMethod]
  public void TestUniqueName()
  {
      var graph = tf.Graph().as_default();
      with<Graph>(graph, g =>
      {
          // The low-level variant below is disabled; the test currently creates the
          // ops through constant_op.constant instead.
          //var (c_op,op_desc) = ops._create_c_op(g, ops._NodeDef("Const", "myop"), new Tensor[0], new Operation[0]);
          //var (c_op2, op_desc1) = ops._create_c_op(g, ops._NodeDef("Const", "myop_1"), new Tensor[0], new Operation[0]);
          //var op = g._create_op_from_tf_operation(c_op);
          //var op2 = g._create_op_from_tf_operation(c_op2);

          // Creation order matters: the expected names depend on which names
          // are already present in the graph when each op is added.
          var first = constant_op.constant(0, name: "myop").op;
          var second = constant_op.constant(0, name: "myop_1").op;

          // Same requested names again; these are expected to be uniquified.
          var firstDuplicate = constant_op.constant(0, name: "myop").op;
          var secondDuplicate = constant_op.constant(0, name: "myop_1").op;

          self.assertEqual(first.name, "myop");
          self.assertEqual(second.name, "myop_1");
          self.assertEqual(firstDuplicate.name, "myop_2");
          self.assertEqual(secondDuplicate.name, "myop_1_1");
      });
  }
  /// <summary>
  /// Creates an "Identity" op named "cond/myop" inside the true branch of
  /// <c>control_flow_ops.cond</c> through the low-level <c>ops._create_c_op</c> /
  /// <c>Graph._add_new_tf_operations</c> path, then checks the op's name, type,
  /// input wiring (expected to come from a "Switch" op fed by the constant),
  /// owning graph and control flow context name.
  /// </summary>
  /// <remarks>
  /// Currently ignored; see the reason on the <c>Ignore</c> attribute.
  /// </remarks>
  [Ignore("need tensorflow expose UpdateEdge API")]
  [TestMethod]
  public void TestCond()
  {
      var graph = tf.Graph().as_default();
      with(graph, g =>
      {
          var x = constant_op.constant(10);
          var true_fn = new Func<Tensor>(() =>
          {
              // Only the effect on the graph matters here; the returned native
              // handles are not used by the test.
              ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]);
              var new_ops = g._add_new_tf_operations();
              self.assertEqual(len(new_ops), 1);
              return x;
          });

          control_flow_ops.cond(x < 10, true_fn, () => x);

          var op = g.get_operation_by_name("cond/myop");
          self.assertIsNotNone(op);
          self.assertEqual(op.name, "cond/myop");
          self.assertEqual(op.type, "Identity");
          // NOTE(review): the output-list check is left disabled
          // (`self.assertEqual(op.outputs, new object[0]);`) - presumably because an
          // "Identity" op does have an output; confirm when un-ignoring this test.

          var op_input = op.inputs[0].op;
          self.assertEqual(op_input.type, "Switch");
          self.assertEqual(op_input.inputs[0].name, x.name);
          self.assertEqual(op.graph, g);
          self.assertIsNotNone(op._get_control_flow_context());

          // Direct cast rather than `as`: a context of an unexpected type should
          // fail with an InvalidCastException instead of a NullReferenceException
          // on the `.name` access below.
          var cond_text = (ControlFlowContext)op._get_control_flow_context();
          self.assertEqual(cond_text.name, "cond/cond_text");
      });
  }
  /// <summary>
  /// Partial port of <c>testWhileLoop</c> from tensorflow/python/framework/ops_test.py:
  /// an op named "myloop/myop" is meant to be created inside a while-loop body through
  /// <c>ops._create_c_op</c> / <c>Graph._add_new_tf_operations</c>, after which its name,
  /// type, outputs, input wiring, graph and control flow context name are checked.
  /// </summary>
  /// <remarks>
  /// Ignored because <c>control_flow_ops.while_loop</c> is not ported yet: the call is
  /// commented out, so <c>body</c> is defined but never invoked. The Python original is
  /// kept at the end of the method as the porting reference.
  /// </remarks>
  [Ignore("Todo: Port")]
  [TestMethod]
  public void TestWhileLoop()
  {
      var graph = tf.Graph().as_default();
      // Declared outside the lambda so the assertions after the `with` block can use it.
      Operation x = null;
      with<Graph>(graph, g =>
      {
          x = constant_op.constant(42);
          var body = new Func<int, int>(i =>
          {
              ops._create_c_op(ops.get_default_graph(), ops._NodeDef("Identity", "myloop/myop"), new[] { x },
                  new Operation[0]);
              var new_ops = g._add_new_tf_operations();
              self.assertEqual(len(new_ops), 1);
              return i;
          });
          // TODO: port control_flow_ops.while_loop
          //control_flow_ops.while_loop( i => i < 10, body, new int[]{0}, name = "myloop");
      });

      var op = graph.get_operation_by_name("myloop/myop");
      self.assertIsNotNone(op);
      self.assertEqual(op.name, "myloop/myop");
      self.assertEqual(op.type, "Identity");
      // The Python original uses the "IntInput" test op and asserts an empty output
      // list. This port creates an "Identity" op instead, which has exactly one
      // output (TestShape asserts the same), so the expected count is 1, not 0.
      self.assertEqual(op.outputs.Length, 1);

      var op_input = op.inputs[0].op;
      self.assertEqual(op_input.type, "Enter");
      self.assertItemsEqual(op_input.inputs.OfType<Operation>().ToArray(), new[] { x });
      self.assertEqual(op.graph, graph);
      self.assertIsNotNone(op._get_control_flow_context());

      var whileContext = (ControlFlowContext)op._get_control_flow_context();
      self.assertEqual(whileContext.name, "myloop/while_context");

      /*
      @test_util.run_v1_only("b/120545219")
      def testWhileLoop(self):
        g = ops.Graph()
        with g.as_default():
          x = test_ops.int_output()

          def body(i):
            ops._create_c_op(ops.get_default_graph(),
                             ops._NodeDef("IntInput", "myloop/myop"), [x], [])
            new_ops = g._add_new_tf_operations()
            self.assertEqual(len(new_ops), 1)
            return i

          control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

        op = g.get_operation_by_name("myloop/myop")
        self.assertIsNotNone(op)
        self.assertEqual(op.name, "myloop/myop")
        self.assertEqual(op.type, "IntInput")
        self.assertEqual(op.outputs, [])
        op_input = op.inputs[0].op
        self.assertEqual(op_input.type, "Enter")
        self.assertEqual(list(op_input.inputs), [x])
        self.assertEqual(op.graph, g)
        # pylint: disable=protected-access
        self.assertIsNotNone(op._get_control_flow_context())
        self.assertEqual(op._get_control_flow_context().name,
                         "myloop/while_context")
        # pylint: enable=protected-access
      */
  }
  /// <summary>
  /// Placeholder for the port of <c>testWhileLoopWithInternalControlDep</c>.
  /// The method body is empty; the Python test to port is kept below as a comment.
  /// </summary>
  [TestMethod, Ignore("Todo: Port")]
  public void TestWhileLoopWithInternalControlDep()
  {
      /*
      @test_util.run_v1_only("b/120545219")
      def testWhileLoopWithInternalControlDep(self):
        g = ops.Graph()
        with g.as_default():
          x = test_ops.int_output()

          def body(i):
            c = constant_op.constant(1.0, name="c")
            ops._create_c_op(ops.get_default_graph(),
                             ops._NodeDef("IntInput", "myloop/myop"), [x], [])
            with ops.control_dependencies([c]):
              new_ops = g._add_new_tf_operations()
              self.assertEqual(len(new_ops), 1)
            return i

          control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

        op = g.get_operation_by_name("myloop/myop")
        self.assertIsNotNone(op)
        c = g.get_operation_by_name("myloop/c")
        self.assertIsNotNone(c)
        # Internal control dep is preserved
        self.assertEqual(op.control_inputs, [c])
      */
  }
  /// <summary>
  /// Placeholder for the port of <c>testWhileLoopWithExternalControlDep</c>.
  /// The method body is empty; the Python test to port is kept below as a comment.
  /// </summary>
  [TestMethod, Ignore("Todo: Port")]
  public void TestWhileLoopWithExternalControlDep()
  {
      /*
      @test_util.run_v1_only("b/120545219")
      def testWhileLoopWithExternalControlDep(self):
        g = ops.Graph()
        with g.as_default():
          x = test_ops.int_output()
          c = constant_op.constant(1.0)

          def body(i):
            ops._create_c_op(ops.get_default_graph(),
                             ops._NodeDef("IntInput", "myloop/myop"), [x], [])
            with ops.control_dependencies([c]):
              new_ops = g._add_new_tf_operations()
              self.assertEqual(len(new_ops), 1)
            return i

          control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

        op = g.get_operation_by_name("myloop/myop")
        self.assertIsNotNone(op)
        # External control dep is removed and replaced with internal control dep
        self.assertNotEqual(op.control_inputs[0], c.op)
        self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
      */
  }
  206. }
  207. }