You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CreateOpFromTfOperationTest.cs 9.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using Microsoft.VisualStudio.TestTools.UnitTesting;
  5. using Tensorflow;
  6. namespace TensorFlowNET.UnitTest
  7. {
  /// <summary>
  /// Port of an excerpt of tensorflow/python/framework/ops_test.py.
  ///
  /// Rationale quoted from the Python source:
  /// # These cases test the private Graph._create_op_from_tf_operation
  /// # method. Arguably we should only test the public APIs that depend on this
  /// # method. However, this logic is complex and tricky, and it can be difficult to
  /// # ascertain if we have adequate coverage (e.g. a graph may run successfully if
  /// # the control flow context isn't set properly, but a more complicated use case
  /// # that might not be obvious to test will fail). Thus we instead explicitly test
  /// # the low-level behavior.
  /// </summary>
  [TestClass]
  public class CreateOpFromTfOperationTest : PythonTest
  {
      /// <summary>
      /// Creates an "Identity" op named "myop" through the low-level C-op API, fed by a
      /// 2x3 integer constant, wraps it with <c>_create_op_from_tf_operation</c> and checks
      /// the wrapper's name, type, number of outputs and the shape of its single output.
      /// </summary>
      [TestMethod]
      public void TestShape()
      {
          with<Graph>(tf.Graph().as_default(), defaultGraph =>
          {
              var matrix = constant_op.constant(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } });
              var nodeDef = ops._NodeDef("Identity", "myop");

              // Only the native op handle is needed; the second tuple element is discarded.
              var (nativeOp, _) = ops._create_c_op(defaultGraph, nodeDef, new[] { matrix }, new Operation[0]);
              var wrapped = defaultGraph._create_op_from_tf_operation(nativeOp);

              Assert.AreEqual("myop", wrapped.name);
              Assert.AreEqual("Identity", wrapped.type);
              Assert.AreEqual(1, len(wrapped.outputs));
              assertItemsEqual(new int[] { 2, 3 }, wrapped.outputs[0].shape);
          });
      }

      /// <summary>
      /// Creates two constants named "myop" and "myop_1", then two more with the same
      /// requested names, and checks that the later pair receives the uniquified names
      /// "myop_2" and "myop_1_1".
      /// </summary>
      [TestMethod]
      public void TestUniqueName()
      {
          with<Graph>(tf.Graph().as_default(), defaultGraph =>
          {
              // Disabled variant that builds the first two ops through the low-level API:
              //var (c_op,op_desc) = ops._create_c_op(g, ops._NodeDef("Const", "myop"), new Tensor[0], new Operation[0]);
              //var (c_op2, op_desc1) = ops._create_c_op(g, ops._NodeDef("Const", "myop_1"), new Tensor[0], new Operation[0]);
              //var op = g._create_op_from_tf_operation(c_op);
              //var op2 = g._create_op_from_tf_operation(c_op2);
              var first = constant_op.constant(0, name: "myop").op;
              var second = constant_op.constant(0, name: "myop_1").op;

              // Request the names already taken by the first two ops; the graph is
              // expected to hand out uniquified names instead.
              var third = constant_op.constant(0, name: "myop").op;
              var fourth = constant_op.constant(0, name: "myop_1").op;

              self.assertEqual(first.name, "myop");
              self.assertEqual(second.name, "myop_1");
              self.assertEqual(third.name, "myop_2");
              self.assertEqual(fourth.name, "myop_1_1");
          });
      }

      /// <summary>
      /// Creates an "Identity" op named "cond/myop" inside the true branch of a
      /// <c>cond</c> and checks that, once looked up by name, it reports the expected
      /// name, type, "Switch" input, owning graph and a non-null control flow context.
      /// Currently skipped via <c>[Ignore]</c>; the Python test it was ported from is
      /// kept in the trailing comment for reference.
      /// </summary>
      [TestMethod, Ignore("Something is not right, Switch gets not inserted correctly?")]
      public void TestCond()
      {
          with<Graph>(tf.Graph().as_default(), defaultGraph =>
          {
              var input = constant_op.constant(10);

              Func<Tensor> trueBranch = () =>
              {
                  // The returned handles are not used; the op is picked up by
                  // _add_new_tf_operations below.
                  var (_, _) = ops._create_c_op(defaultGraph, ops._NodeDef("Identity", "cond/myop"), new[] { input }, new Operation[0]);
                  var addedOps = defaultGraph._add_new_tf_operations();
                  self.assertEqual(len(addedOps), 1);
                  return input;
              };

              control_flow_ops.cond(input < 10, trueBranch, () => input);

              var condOp = defaultGraph.get_operation_by_name("cond/myop");
              self.assertIsNotNone(condOp);
              self.assertEqual(condOp.name, "cond/myop");
              self.assertEqual(condOp.type, "Identity");
              //self.assertEqual(op.outputs, new object[0]);

              var producer = condOp.inputs[0].op;
              self.assertEqual(producer.type, "Switch");
              self.assertEqual(producer.inputs[0], input);
              self.assertEqual(condOp.graph, defaultGraph);
              self.assertIsNotNone(condOp._get_control_flow_context());
              // TODO: op._get_control_flow_context().name not implemented
              //self.assertEqual(op._get_control_flow_context().name, "cond/cond_text");
          });

          /*
          Python reference:

          @test_util.run_v1_only("b/120545219")
          def testCond(self):
            g = ops.Graph()
            with g.as_default():
              x = test_ops.int_output()

              def true_fn():
                ops._create_c_op(ops.get_default_graph(),
                                 ops._NodeDef("IntInput", "cond/myop"), [x], [])
                new_ops = g._add_new_tf_operations()
                self.assertEqual(len(new_ops), 1)
                return x

              control_flow_ops.cond(x < 10, true_fn, lambda: x)

            op = g.get_operation_by_name("cond/myop")
            self.assertIsNotNone(op)
            self.assertEqual(op.name, "cond/myop")
            self.assertEqual(op.type, "IntInput")
            self.assertEqual(op.outputs, [])
            op_input = op.inputs[0].op
            self.assertEqual(op_input.type, "Switch")
            self.assertEqual(op_input.inputs[0], x)
            self.assertEqual(op.graph, g)
            # pylint: disable=protected-access
            self.assertIsNotNone(op._get_control_flow_context())
            self.assertEqual(op._get_control_flow_context().name,
                             "cond/cond_text")
            # pylint: enable=protected-access
          */
      }

      /*
      Python reference tests that have no C# counterpart in this class yet:

      @test_util.run_v1_only("b/120545219")
      def testWhileLoop(self):
        g = ops.Graph()
        with g.as_default():
          x = test_ops.int_output()

          def body(i):
            ops._create_c_op(ops.get_default_graph(),
                             ops._NodeDef("IntInput", "myloop/myop"), [x], [])
            new_ops = g._add_new_tf_operations()
            self.assertEqual(len(new_ops), 1)
            return i

          control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

        op = g.get_operation_by_name("myloop/myop")
        self.assertIsNotNone(op)
        self.assertEqual(op.name, "myloop/myop")
        self.assertEqual(op.type, "IntInput")
        self.assertEqual(op.outputs, [])
        op_input = op.inputs[0].op
        self.assertEqual(op_input.type, "Enter")
        self.assertEqual(list(op_input.inputs), [x])
        self.assertEqual(op.graph, g)
        # pylint: disable=protected-access
        self.assertIsNotNone(op._get_control_flow_context())
        self.assertEqual(op._get_control_flow_context().name,
                         "myloop/while_context")
        # pylint: enable=protected-access

      @test_util.run_v1_only("b/120545219")
      def testWhileLoopWithInternalControlDep(self):
        g = ops.Graph()
        with g.as_default():
          x = test_ops.int_output()

          def body(i):
            c = constant_op.constant(1.0, name="c")
            ops._create_c_op(ops.get_default_graph(),
                             ops._NodeDef("IntInput", "myloop/myop"), [x], [])
            with ops.control_dependencies([c]):
              new_ops = g._add_new_tf_operations()
              self.assertEqual(len(new_ops), 1)
            return i

          control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

        op = g.get_operation_by_name("myloop/myop")
        self.assertIsNotNone(op)
        c = g.get_operation_by_name("myloop/c")
        self.assertIsNotNone(c)
        # Internal control dep is preserved
        self.assertEqual(op.control_inputs, [c])

      @test_util.run_v1_only("b/120545219")
      def testWhileLoopWithExternalControlDep(self):
        g = ops.Graph()
        with g.as_default():
          x = test_ops.int_output()
          c = constant_op.constant(1.0)

          def body(i):
            ops._create_c_op(ops.get_default_graph(),
                             ops._NodeDef("IntInput", "myloop/myop"), [x], [])
            with ops.control_dependencies([c]):
              new_ops = g._add_new_tf_operations()
              self.assertEqual(len(new_ops), 1)
            return i

          control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

        op = g.get_operation_by_name("myloop/myop")
        self.assertIsNotNone(op)
        # External control dep is removed and replaced with internal control dep
        self.assertNotEqual(op.control_inputs[0], c.op)
        self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
      */
  }
  184. }

tensorflow框架的.NET版本,提供了丰富的特性和API,可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。