You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ControlDependenciesTest.cs 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. using System;
  2. using System.Linq;
  3. using Microsoft.VisualStudio.TestTools.UnitTesting;
  4. using Tensorflow;
  5. using Tensorflow.Eager;
  6. using Tensorflow.UnitTest;
  7. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.ops_test
  {
      /// <summary>
      /// excerpt of tensorflow/python/framework/ops_test.py
      /// </summary>
      /// <remarks>
      /// Each test builds a graph with <c>tf.Graph().as_default()</c>, creates ops inside
      /// (possibly nested) <c>g.control_dependencies(...)</c> scopes via <c>tf_with</c>, and
      /// then inspects the resulting <c>op.control_inputs</c>.
      /// <c>assertItemsEqual</c> is consistently called as <c>assertItemsEqual(given, expected)</c>
      /// so that failure messages report expected and actual values the right way round.
      /// </remarks>
      [TestClass]
      public class ControlDependenciesTest : GraphModeTestBase
      {
          /// <summary>
          /// Ops created inside a control-dependency scope on <c>a</c> get <c>a</c> as a control
          /// input, except when the dependency is already implied through a data input.
          /// </summary>
          [TestMethod]
          public void TestBasic()
          {
              var g = tf.Graph().as_default();
              var a = constant_op.constant(1.0);
              var b = constant_op.constant(1.0);
              // Assigned inside the lambda below, so they must be declared out here.
              Tensor c = null, d = null, e = null;
              tf_with(g.control_dependencies(new[] { a }), x =>
              {
                  c = constant_op.constant(1.0);
                  d = array_ops.identity(b);
                  e = array_ops.identity(c);
              });

              assertItemsEqual(c.op.control_inputs, new[] { a.op });
              assertItemsEqual(d.op.control_inputs, new[] { a.op });
              // e should be dominated by c.
              Assert.AreEqual(0, e.op.control_inputs.Length);
          }

          /// <summary>
          /// Port of testBasicWithConversion; not implemented yet (see the TODO below).
          /// </summary>
          [Ignore("How to port the ConvertibleObj?")]
          [TestMethod]
          public void TestBasicWithConversion()
          {
              var g = tf.Graph().as_default();
              // Note: _apply_op can be replaced by g.create_op
              var a = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
              // TODO: ConvertibleObj, see original source below
              /*
              def testBasicWithConversion(self):
                  g = ops.Graph()
                  a = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                  class ConvertibleObj(object):
                      def _as_graph_element(self):
                          return a

                  with g.control_dependencies([ConvertibleObj()]):
                      c = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                  self.assertEqual(c.op.control_inputs, [a.op])
              */
          }

          /// <summary>
          /// One scope listing a_1..a_4 and four nested single-item scopes must yield the same
          /// control inputs.
          /// </summary>
          [TestMethod]
          public void TestNested()
          {
              var g = tf.Graph().as_default();
              var a_1 = constant_op.constant(1.0);
              var a_2 = constant_op.constant(3.0);
              var a_3 = constant_op.constant(4.0);
              var a_4 = constant_op.constant(5.0);
              Tensor b_1 = null, b_2 = null;

              tf_with(g.control_dependencies(new[] { a_1, a_2, a_3, a_4 }), ctrl =>
              {
                  b_1 = constant_op.constant(6.0);
              });

              tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
              {
                  tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
                  {
                      tf_with(g.control_dependencies(new[] { a_3 }), ctrl3 =>
                      {
                          tf_with(g.control_dependencies(new[] { a_4 }), ctrl4 =>
                          {
                              b_2 = constant_op.constant(7.0);
                          });
                      });
                  });
              });

              assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op });
              assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs);
          }

          /// <summary>
          /// <c>control_dependencies(null)</c> clears the inherited dependencies for its scope.
          /// The outer dependencies are restored when that scope exits.
          /// </summary>
          [TestMethod]
          public void TestClear()
          {
              var g = tf.Graph().as_default();
              var a_1 = constant_op.constant(1.0);
              var a_2 = constant_op.constant(3.0);
              var a_3 = constant_op.constant(4.0);
              var a_4 = constant_op.constant(5.0);
              Operation b_3_4 = null, b_3 = null, b_none = null, b_1 = null, b_1_2 = null, b_none2 = null;

              tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
              {
                  tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
                  {
                      tf_with(g.control_dependencies(null), ctrl3 =>
                      {
                          tf_with(g.control_dependencies(new[] { a_3 }), ctrl4 =>
                          {
                              tf_with(g.control_dependencies(new[] { a_4 }), ctrl5 =>
                              {
                                  // deps [a_3, a_4]
                                  b_3_4 = constant_op.constant(7.0);
                              });
                              // deps = [a_3]
                              b_3 = constant_op.constant(8.0);
                          });
                          // deps back to None
                          b_none = constant_op.constant(9.0);
                      });
                      // deps back to [a_1, a_2]
                      b_1_2 = constant_op.constant(10.0);
                  });
                  // deps back to [a_1]
                  b_1 = constant_op.constant(11.0);
                  tf_with(g.control_dependencies(null), ctrl6 =>
                  {
                      // deps are None again
                      b_none2 = constant_op.constant(12.0);
                  });
              });

              assertItemsEqual(b_3_4.op.control_inputs, new[] { a_3.op, a_4.op });
              assertItemsEqual(b_3.op.control_inputs, new[] { a_3.op });
              assertItemsEqual(b_none.op.control_inputs, Array.Empty<object>());
              assertItemsEqual(b_1_2.op.control_inputs, new[] { a_1.op, a_2.op });
              assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op });
              assertItemsEqual(b_none2.op.control_inputs, Array.Empty<object>());
          }

          /// <summary>
          /// Mixes nested control-dependency scopes with data dependencies.
          /// </summary>
          [TestMethod]
          public void TestComplex()
          {
              var g = tf.Graph().as_default();
              // Usage pattern:
              // * Nodes a_i are constants defined at the outermost scope, and are used
              //   as control inputs for the ith nested scope.
              // * Nodes b_i are defined as Mul(a_3, a_4) at each scope.
              // * Nodes c_i are defined as Mul(a_1, b_1) at each scope.
              // * Nodes d_i are defined as Mul(b_i, c_i) at each scope.
              // * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1.
              var a_1 = constant_op.constant(1.0);
              var a_2 = constant_op.constant(2.0);
              var a_3 = constant_op.constant(3.0);
              var a_4 = constant_op.constant(4.0);
              Operation b_1 = null, b_2 = null, b_3 = null, b_4 = null;
              Operation c_1 = null, c_2 = null, c_3 = null, c_4 = null;
              Operation d_1 = null, d_2 = null, d_3 = null, d_4 = null;
              Operation e_1 = null, e_2 = null, e_3 = null, e_4 = null;

              tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
              {
                  b_1 = tf.multiply(a_3, a_4);
                  c_1 = tf.multiply(a_1, b_1.output);
                  d_1 = tf.multiply(b_1.output, c_1.output);
                  e_1 = constant_op.constant(5.0);
                  tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
                  {
                      b_2 = tf.multiply(a_3, a_4);
                      c_2 = tf.multiply(a_1, b_1.output);
                      d_2 = tf.multiply(b_2.output, c_2.output);
                      e_2 = tf.multiply(e_1.output, e_1.output);
                      tf_with(g.control_dependencies(new[] { a_3 }), ctrl3 =>
                      {
                          b_3 = tf.multiply(a_3, a_4);
                          c_3 = tf.multiply(a_1, b_1.output);
                          d_3 = tf.multiply(b_3.output, c_3.output);
                          e_3 = tf.multiply(e_2.output, e_2.output);
                          tf_with(g.control_dependencies(new[] { a_4 }), ctrl4 =>
                          {
                              b_4 = tf.multiply(a_3, a_4);
                              c_4 = tf.multiply(a_1, b_1.output);
                              d_4 = tf.multiply(b_4.output, c_4.output);
                              e_4 = tf.multiply(e_3.output, e_3.output);
                          });
                      });
                  });
              });

              assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op });
              assertItemsEqual(b_2.op.control_inputs, new[] { a_1.op, a_2.op });
              assertItemsEqual(b_3.op.control_inputs, new[] { a_1.op, a_2.op });
              assertItemsEqual(b_4.op.control_inputs, new[] { a_1.op, a_2.op });

              assertItemsEqual(c_1.op.control_inputs, Array.Empty<object>());
              assertItemsEqual(c_2.op.control_inputs, new[] { a_2.op });
              assertItemsEqual(c_3.op.control_inputs, new[] { a_2.op, a_3.op });
              assertItemsEqual(c_4.op.control_inputs, new[] { a_2.op, a_3.op, a_4.op });

              assertItemsEqual(d_1.op.control_inputs, Array.Empty<object>());
              assertItemsEqual(d_2.op.control_inputs, Array.Empty<object>());
              assertItemsEqual(d_3.op.control_inputs, Array.Empty<object>());
              assertItemsEqual(d_4.op.control_inputs, Array.Empty<object>());

              assertItemsEqual(e_1.op.control_inputs, new[] { a_1.op });
              assertItemsEqual(e_2.op.control_inputs, new[] { a_2.op });
              assertItemsEqual(e_3.op.control_inputs, new[] { a_3.op });
              assertItemsEqual(e_4.op.control_inputs, new[] { a_4.op });
          }

          /// <summary>
          /// Port of testRepeatedDependency; not implemented yet (see the original source below).
          /// </summary>
          [Ignore("Don't know how to create an operation with two outputs")]
          [TestMethod]
          public void TestRepeatedDependency()
          {
              /*
              def testRepeatedDependency(self):
                  g = ops.Graph()
                  a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
                  a_0, a_1 = a.outputs
                  with g.control_dependencies([a_0]):
                      b = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                      with g.control_dependencies([a_1]):
                          c = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                  self.assertEqual(b.op.control_inputs, [a])
                  self.assertEqual(c.op.control_inputs, [a])
              */
          }

          /// <summary>
          /// An op that already consumes <c>a</c> as a data input gets no control input on
          /// <c>a</c>, even inside <c>control_dependencies([a])</c>.
          /// </summary>
          [TestMethod]
          public void TestNoControlDependencyWithDataDependency()
          {
              var g = tf.Graph().as_default();
              var a = constant_op.constant(100.0);
              // Assigned inside the lambda below, so it must be declared out here.
              Operation b = null;
              tf_with(g.control_dependencies(new[] { a }), ctrl1 =>
              {
                  b = array_ops.identity(a);
              });

              Assert.AreEqual(0, b.op.control_inputs.Length);
          }
      }
  }