You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ControlDependenciesTest.cs 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System.Linq;
  3. using Tensorflow;
  4. using Tensorflow.UnitTest;
  5. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.ops_test
  {
      /// <summary>
      /// excerpt of tensorflow/python/framework/ops_test.py
      /// </summary>
      /// <remarks>
      /// Each test builds a default graph, creates ops inside (possibly nested)
      /// <c>g.control_dependencies(...)</c> scopes and then inspects the resulting
      /// <c>op.control_inputs</c>. <c>assertItemsEqual</c> is always called as
      /// <c>assertItemsEqual(given, expected)</c>: actual control inputs first, expected second.
      /// </remarks>
      [TestClass]
      public class ControlDependenciesTest : GraphModeTestBase
      {
          /// <summary>
          /// Ops created inside a single control-dependency scope pick up the scope's control
          /// input; an op whose data input already carries that dependency (e via c) is
          /// expected to get none.
          /// </summary>
          [TestMethod]
          public void TestBasic()
          {
              var g = tf.Graph().as_default();
              var a = constant_op.constant(1.0);
              var b = constant_op.constant(1.0);
              Tensor c = null, d = null, e = null;

              tf_with(g.control_dependencies(new[] { a }), x =>
              {
                  c = constant_op.constant(1.0);
                  d = array_ops.identity(b);
                  e = array_ops.identity(c);
              });

              // CollectionAssert reports the first mismatching index, unlike Assert.IsTrue(SequenceEqual(...)).
              CollectionAssert.AreEqual(new[] { a.op }, c.op.control_inputs);
              CollectionAssert.AreEqual(new[] { a.op }, d.op.control_inputs);
              // e should be dominated by c: its data input c already depends on a.
              Assert.AreEqual(0, e.op.control_inputs.Length);
          }

          /// <summary>
          /// Port of testBasicWithConversion; disabled until an equivalent of the Python
          /// ConvertibleObj (an object exposing _as_graph_element) exists.
          /// </summary>
          [Ignore("How to port the ConvertibleObj?")]
          [TestMethod]
          public void TestBasicWithConversion()
          {
              var g = tf.Graph().as_default();
              // Note: _apply_op can be replaced by g.create_op
              var a = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
              // TODO: ConvertibleObj, see original source below
              /*
              def testBasicWithConversion(self):
                  g = ops.Graph()
                  a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  class ConvertibleObj(object):
                      def _as_graph_element(self):
                          return a
                  with g.control_dependencies([ConvertibleObj()]):
                      c = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  self.assertEqual(c.op.control_inputs, [a.op])
              */
          }

          /// <summary>
          /// Nested scopes accumulate control inputs: four nested single-op scopes are expected
          /// to yield the same control inputs as one scope listing all four ops.
          /// </summary>
          [TestMethod]
          public void TestNested()
          {
              var g = tf.Graph().as_default();
              var a_1 = constant_op.constant(1.0);
              var a_2 = constant_op.constant(3.0);
              var a_3 = constant_op.constant(4.0);
              var a_4 = constant_op.constant(5.0);
              Tensor b_1 = null, b_2 = null;

              tf_with(g.control_dependencies(new[] { a_1, a_2, a_3, a_4 }), ctrl =>
              {
                  b_1 = constant_op.constant(6.0);
              });

              tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
              {
                  tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
                  {
                      tf_with(g.control_dependencies(new[] { a_3 }), ctrl3 =>
                      {
                          tf_with(g.control_dependencies(new[] { a_4 }), ctrl4 =>
                          {
                              b_2 = constant_op.constant(7.0);
                          });
                      });
                  });
              });

              assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op });
              assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs);
          }

          /// <summary>
          /// <c>control_dependencies(null)</c> clears the accumulated control inputs for its
          /// scope; leaving any scope restores the dependencies of the enclosing one.
          /// </summary>
          [TestMethod]
          public void TestClear()
          {
              var g = tf.Graph().as_default();
              var a_1 = constant_op.constant(1.0);
              var a_2 = constant_op.constant(3.0);
              var a_3 = constant_op.constant(4.0);
              var a_4 = constant_op.constant(5.0);
              Operation b_3_4 = null, b_3 = null, b_none = null, b_1 = null, b_1_2 = null, b_none2 = null;

              tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
              {
                  tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
                  {
                      tf_with(g.control_dependencies(null), ctrl3 =>
                      {
                          tf_with(g.control_dependencies(new[] { a_3 }), ctrl4 =>
                          {
                              tf_with(g.control_dependencies(new[] { a_4 }), ctrl5 =>
                              {
                                  // deps [a_3, a_4]
                                  b_3_4 = constant_op.constant(7.0);
                              });
                              // deps = [a_3]
                              b_3 = constant_op.constant(8.0);
                          });
                          // deps back to None
                          b_none = constant_op.constant(9.0);
                      });
                      // deps back to [a_1, a_2]
                      b_1_2 = constant_op.constant(10.0);
                  });
                  // deps back to [a_1]
                  b_1 = constant_op.constant(11.0);
                  tf_with(g.control_dependencies(null), ctrl6 =>
                  {
                      // deps are None again
                      b_none2 = constant_op.constant(12.0);
                  });
              });

              assertItemsEqual(b_3_4.op.control_inputs, new[] { a_3.op, a_4.op });
              assertItemsEqual(b_3.op.control_inputs, new[] { a_3.op });
              assertItemsEqual(b_none.op.control_inputs, new object[0]);
              assertItemsEqual(b_1_2.op.control_inputs, new[] { a_1.op, a_2.op });
              assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op });
              assertItemsEqual(b_none2.op.control_inputs, new object[0]);
          }

          /// <summary>
          /// Mixes nested control-dependency scopes with data dependencies; the expectations
          /// below omit any scope op that is already a data input of the op (or of its inputs).
          /// </summary>
          [TestMethod]
          public void TestComplex()
          {
              var g = tf.Graph().as_default();
              // Usage pattern:
              // * Nodes a_i are constants defined at the outermost scope, and are used
              //   as control inputs for the ith nested scope.
              // * Nodes b_i are defined as Mul(a_3, a_4) at each scope.
              // * Nodes c_i are defined as Mul(a_1, b_1) at each scope.
              // * Nodes d_i are defined as Mul(b_i, c_i) at each scope.
              // * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1.
              var a_1 = constant_op.constant(1.0);
              var a_2 = constant_op.constant(2.0);
              var a_3 = constant_op.constant(3.0);
              var a_4 = constant_op.constant(4.0);
              Operation b_1 = null, b_2 = null, b_3 = null, b_4 = null;
              Operation c_1 = null, c_2 = null, c_3 = null, c_4 = null;
              Operation d_1 = null, d_2 = null, d_3 = null, d_4 = null;
              Operation e_1 = null, e_2 = null, e_3 = null, e_4 = null;

              tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
              {
                  b_1 = tf.multiply(a_3, a_4);
                  c_1 = tf.multiply(a_1, b_1.output);
                  d_1 = tf.multiply(b_1.output, c_1.output);
                  e_1 = constant_op.constant(5.0);
                  tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
                  {
                      b_2 = tf.multiply(a_3, a_4);
                      c_2 = tf.multiply(a_1, b_1.output);
                      d_2 = tf.multiply(b_2.output, c_2.output);
                      e_2 = tf.multiply(e_1.output, e_1.output);
                      tf_with(g.control_dependencies(new[] { a_3 }), ctrl3 =>
                      {
                          b_3 = tf.multiply(a_3, a_4);
                          c_3 = tf.multiply(a_1, b_1.output);
                          d_3 = tf.multiply(b_3.output, c_3.output);
                          e_3 = tf.multiply(e_2.output, e_2.output);
                          tf_with(g.control_dependencies(new[] { a_4 }), ctrl4 =>
                          {
                              b_4 = tf.multiply(a_3, a_4);
                              c_4 = tf.multiply(a_1, b_1.output);
                              d_4 = tf.multiply(b_4.output, c_4.output);
                              e_4 = tf.multiply(e_3.output, e_3.output);
                          });
                      });
                  });
              });

              assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op });
              assertItemsEqual(b_2.op.control_inputs, new[] { a_1.op, a_2.op });
              assertItemsEqual(b_3.op.control_inputs, new[] { a_1.op, a_2.op });
              assertItemsEqual(b_4.op.control_inputs, new[] { a_1.op, a_2.op });

              assertItemsEqual(c_1.op.control_inputs, new object[0]);
              assertItemsEqual(c_2.op.control_inputs, new[] { a_2.op });
              assertItemsEqual(c_3.op.control_inputs, new[] { a_2.op, a_3.op });
              assertItemsEqual(c_4.op.control_inputs, new[] { a_2.op, a_3.op, a_4.op });

              assertItemsEqual(d_1.op.control_inputs, new object[0]);
              assertItemsEqual(d_2.op.control_inputs, new object[0]);
              assertItemsEqual(d_3.op.control_inputs, new object[0]);
              assertItemsEqual(d_4.op.control_inputs, new object[0]);

              assertItemsEqual(e_1.op.control_inputs, new[] { a_1.op });
              assertItemsEqual(e_2.op.control_inputs, new[] { a_2.op });
              assertItemsEqual(e_3.op.control_inputs, new[] { a_3.op });
              assertItemsEqual(e_4.op.control_inputs, new[] { a_4.op });
          }

          /// <summary>
          /// Port of testRepeatedDependency; disabled because the C# port has no way yet to
          /// create an operation with two outputs.
          /// </summary>
          [Ignore("Don't know how to create an operation with two outputs")]
          [TestMethod]
          public void TestRepeatedDependency()
          {
              /*
              def testRepeatedDependency(self):
                  g = ops.Graph()
                  a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
                  a_0, a_1 = a.outputs
                  with g.control_dependencies([a_0]):
                      b = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                      with g.control_dependencies([a_1]):
                          c = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  self.assertEqual(b.op.control_inputs, [a])
                  self.assertEqual(c.op.control_inputs, [a])
              */
          }

          /// <summary>
          /// An op that consumes the scope's control op as a data input is expected to get
          /// no separate control input for it.
          /// </summary>
          [TestMethod]
          public void TestNoControlDependencyWithDataDependency()
          {
              var g = tf.Graph().as_default();
              Operation b = null;
              var a = constant_op.constant(100.0);

              tf_with(g.control_dependencies(new[] { a }), ctrl1 =>
              {
                  b = array_ops.identity(a);
              });

              Assert.AreEqual(0, b.op.control_inputs.Length);
          }
      }
  }