You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

ControlDependenciesTest.cs 14 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. using System;
  2. using System.Linq;
  3. using Microsoft.VisualStudio.TestTools.UnitTesting;
  4. using Tensorflow;
  5. using Tensorflow.Eager;
  6. using static Tensorflow.Binding;
  7. namespace TensorFlowNET.UnitTest.ops_test
  8. {
  9. /// <summary>
  10. /// excerpt of tensorflow/python/framework/ops_test.py
  11. /// </summary>
  12. [TestClass]
  13. public class ControlDependenciesTest : PythonTest
  14. {
  15. [TestMethod]
  16. public void TestBasic()
  17. {
  18. var g = tf.Graph().as_default();
  19. Tensor a = null, b = null, c = null, d = null, e = null;
  20. a = constant_op.constant(1.0);
  21. b = constant_op.constant(1.0);
  22. tf_with(g.control_dependencies(new[] { a }), x =>
  23. {
  24. c = constant_op.constant(1.0);
  25. d = array_ops.identity(b);
  26. e = array_ops.identity(c);
  27. });
  28. Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] { a.op }));
  29. Assert.IsTrue(Enumerable.SequenceEqual(d.op.control_inputs, new[] { a.op }));
  30. // e should be dominated by c.
  31. Assert.AreEqual(0, e.op.control_inputs.Length);
  32. }
  33. [Ignore("Future is not supported yet")]
  34. [TestMethod]
  35. public void TestEager()
  36. {
  37. Tensor a = null, c = null;
  38. object b = null;
  39. var calls = 0;
  40. Func<Tensor> future = () =>
  41. {
  42. calls += 1;
  43. return constant_op.constant(2.0);
  44. };
  45. using (var opts = new ContextOptions())
  46. using (var status = new Status())
  47. using (var context = new Context(opts, status))
  48. {
  49. if (context.executing_eagerly())
  50. {
  51. // TODO: make this compile (see original Python code below)
  52. a = constant_op.constant(1.0);
  53. b = future; // <--- {henon} obviously, this doesn't compile, looks like control_dependencies needs to be able to take callables as well.
  54. tf_with(ops.control_dependencies(new object[] { a, b }), ctrl =>
  55. {
  56. return c = constant_op.constant(3.0);
  57. });
  58. Assert.AreEqual(calls, 1);
  59. }
  60. else
  61. {
  62. var g = tf.Graph().as_default();
  63. a = constant_op.constant(1.0);
  64. var b1 = future();
  65. tf_with(g.control_dependencies(new[] { a, b }), ctrl =>
  66. {
  67. c = constant_op.constant(3.0);
  68. });
  69. Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] { a.op, b1.op }));
  70. Assert.AreEqual(1, calls);
  71. }
  72. }
  73. /*
  74. def testEager(self):
  75. def future():
  76. future.calls += 1
  77. return constant_op.constant(2.0)
  78. future.calls = 0
  79. if context.executing_eagerly():
  80. a = constant_op.constant(1.0)
  81. b = future
  82. with ops.control_dependencies([a, b]):
  83. c = constant_op.constant(3.0)
  84. self.assertEqual(future.calls, 1)
  85. else:
  86. g = ops.Graph()
  87. with g.as_default():
  88. a = constant_op.constant(1.0)
  89. b = future()
  90. with g.control_dependencies([a, b]):
  91. c = constant_op.constant(3.0)
  92. self.assertEqual(c.op.control_inputs, [a.op, b.op])
  93. self.assertEqual(future.calls, 1)
  94. */
  95. }
  96. [Ignore("How to port the ConvertibleObj?")]
  97. [TestMethod]
  98. public void TestBasicWithConversion()
  99. {
  100. var g = tf.Graph().as_default();
  101. // Note: _apply_op can be replaced by g.create_op
  102. var a = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
  103. // TODO: ConvertibleObj, see original source below
  104. /*
  105. def testBasicWithConversion(self):
  106. g = ops.Graph()
  107. a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  108. class ConvertibleObj(object):
  109. def _as_graph_element(self):
  110. return a
  111. with g.control_dependencies([ConvertibleObj()]):
  112. c = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  113. self.assertEqual(c.op.control_inputs, [a.op])
  114. */
  115. }
  116. [TestMethod]
  117. public void TestNested()
  118. {
  119. var g = tf.Graph().as_default();
  120. var a_1 = constant_op.constant(1.0);
  121. var a_2 = constant_op.constant(3.0);
  122. var a_3 = constant_op.constant(4.0);
  123. var a_4 = constant_op.constant(5.0);
  124. Tensor b_1 = null, b_2 = null;
  125. tf_with(g.control_dependencies(new[] { a_1, a_2, a_3, a_4 }), ctrl =>
  126. {
  127. b_1 = constant_op.constant(6.0);
  128. });
  129. tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
  130. {
  131. tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
  132. {
  133. tf_with(g.control_dependencies(new[] { a_3 }), ctrl3 =>
  134. {
  135. tf_with(g.control_dependencies(new[] { a_4 }), ctrl4 =>
  136. {
  137. b_2 = constant_op.constant(7.0);
  138. });
  139. });
  140. });
  141. });
  142. //var z=tf.add(a_1, tf.multiply(b_2, b_1));
  143. //with(g.control_dependencies(new[] {z}), ctrl =>
  144. //{
  145. // var z1 = tf.add(a_3, tf.multiply(a_4, a_2));
  146. //});
  147. //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
  148. assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op });
  149. assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs);
  150. }
  151. [TestMethod]
  152. public void TestClear()
  153. {
  154. var g = tf.Graph().as_default();
  155. var a_1 = constant_op.constant(1.0);
  156. var a_2 = constant_op.constant(3.0);
  157. var a_3 = constant_op.constant(4.0);
  158. var a_4 = constant_op.constant(5.0);
  159. Operation b_3_4 = null, b_3 = null, b_none = null, b_1 = null, b_1_2 = null, b_none2 = null;
  160. tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
  161. {
  162. tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
  163. {
  164. tf_with(g.control_dependencies(null), ctrl3 =>
  165. {
  166. tf_with(g.control_dependencies(new[] { a_3 }), ctrl4 =>
  167. {
  168. tf_with(g.control_dependencies(new[] { a_4 }), ctrl5 =>
  169. {
  170. // deps [a_3, a_4]
  171. b_3_4 = constant_op.constant(7.0);
  172. });
  173. // deps = [a_3]
  174. b_3 = constant_op.constant(8.0);
  175. });
  176. // deps back to None
  177. b_none = constant_op.constant(9.0);
  178. });
  179. // deps back to [a_1, a_2]
  180. b_1_2 = constant_op.constant(10.0);
  181. });
  182. // deps back to [a_1]
  183. b_1 = constant_op.constant(11.0);
  184. tf_with(g.control_dependencies(null), ctrl6 =>
  185. {
  186. // deps are None again
  187. b_none2 = constant_op.constant(12.0);
  188. });
  189. });
  190. // Note assertItemsEqual(given, expected), expected and given parameters should be swapped below
  191. assertItemsEqual(new[] { a_3.op, a_4.op }, b_3_4.op.control_inputs);
  192. assertItemsEqual(new[] { a_3.op }, b_3.op.control_inputs);
  193. assertItemsEqual(new object[0], b_none.op.control_inputs);
  194. assertItemsEqual(new[] { a_1.op, a_2.op }, b_1_2.op.control_inputs);
  195. assertItemsEqual(new[] { a_1.op }, b_1.op.control_inputs);
  196. assertItemsEqual(new object[0], b_none2.op.control_inputs);
  197. }
  198. [TestMethod]
  199. public void TestComplex()
  200. {
  201. var g = tf.Graph().as_default();
  202. // Usage pattern:
  203. // * Nodes a_i are constants defined at the outermost scope, and are used
  204. // as control inputs for the ith nested scope.
  205. // * Nodes b_i are defined as Mul(a_3, a_4) at each scope.
  206. // * Nodes c_i are defined as Mul(a_1, b_1) at each scope.
  207. // * Nodes d_i are defined as Mul(b_i, c_i) at each scope.
  208. // * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1.
  209. var a_1 = constant_op.constant(1.0);
  210. var a_2 = constant_op.constant(2.0);
  211. var a_3 = constant_op.constant(3.0);
  212. var a_4 = constant_op.constant(4.0);
  213. Operation b_1 = null, b_2 = null, b_3 = null, b_4 = null;
  214. Operation c_1 = null, c_2 = null, c_3 = null, c_4 = null;
  215. Operation d_1 = null, d_2 = null, d_3 = null, d_4 = null;
  216. Operation e_1 = null, e_2 = null, e_3 = null, e_4 = null;
  217. tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
  218. {
  219. b_1 = tf.multiply(a_3, a_4);
  220. c_1 = tf.multiply(a_1, b_1.output);
  221. d_1 = tf.multiply(b_1.output, c_1.output);
  222. e_1 = constant_op.constant(5.0);
  223. tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
  224. {
  225. b_2 = tf.multiply(a_3, a_4);
  226. c_2 = tf.multiply(a_1, b_1.output);
  227. d_2 = tf.multiply(b_2.output, c_2.output);
  228. e_2 = tf.multiply(e_1.output, e_1.output);
  229. tf_with(g.control_dependencies(new[] { a_3 }), ctrl3 =>
  230. {
  231. b_3 = tf.multiply(a_3, a_4);
  232. c_3 = tf.multiply(a_1, b_1.output);
  233. d_3 = tf.multiply(b_3.output, c_3.output);
  234. e_3 = tf.multiply(e_2.output, e_2.output);
  235. tf_with(g.control_dependencies(new[] { a_4 }), ctrl4 =>
  236. {
  237. b_4 = tf.multiply(a_3, a_4);
  238. c_4 = tf.multiply(a_1, b_1.output);
  239. d_4 = tf.multiply(b_4.output, c_4.output);
  240. e_4 = tf.multiply(e_3.output, e_3.output);
  241. });
  242. });
  243. });
  244. });
  245. // Note assertItemsEqual(given, expected), expected and given parameters should be swapped below
  246. assertItemsEqual(new[] {a_1.op}, b_1.op.control_inputs);
  247. assertItemsEqual(new[] {a_1.op, a_2.op}, b_2.op.control_inputs);
  248. assertItemsEqual(new[] { a_1.op, a_2.op}, b_3.op.control_inputs);
  249. assertItemsEqual(new[] {a_1.op, a_2.op}, b_4.op.control_inputs);
  250. assertItemsEqual(new object[0], c_1.op.control_inputs);
  251. assertItemsEqual(new[] {a_2.op}, c_2.op.control_inputs);
  252. assertItemsEqual(new[] {a_2.op, a_3.op}, c_3.op.control_inputs);
  253. assertItemsEqual(new[] {a_2.op, a_3.op, a_4.op}, c_4.op.control_inputs);
  254. assertItemsEqual(new object[0], d_1.op.control_inputs);
  255. assertItemsEqual(new object[0], d_2.op.control_inputs);
  256. assertItemsEqual(new object[0], d_3.op.control_inputs);
  257. assertItemsEqual(new object[0], d_4.op.control_inputs);
  258. assertItemsEqual(new[] {a_1.op}, e_1.op.control_inputs);
  259. assertItemsEqual(new[] {a_2.op}, e_2.op.control_inputs);
  260. assertItemsEqual(new[] {a_3.op}, e_3.op.control_inputs);
  261. assertItemsEqual(new[] {a_4.op}, e_4.op.control_inputs);
  262. }
  263. [Ignore("Don't know how to create an operation with two outputs")]
  264. [TestMethod]
  265. public void TestRepeatedDependency()
  266. {
  267. /*
  268. def testRepeatedDependency(self):
  269. g = ops.Graph()
  270. a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
  271. a_0, a_1 = a.outputs
  272. with g.control_dependencies([a_0]):
  273. b = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  274. with g.control_dependencies([a_1]):
  275. c = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  276. self.assertEqual(b.op.control_inputs, [a])
  277. self.assertEqual(c.op.control_inputs, [a])
  278. */
  279. }
  280. [TestMethod]
  281. public void TestNoControlDependencyWithDataDependency()
  282. {
  283. var g = tf.Graph().as_default();
  284. Operation b = null;
  285. var a = constant_op.constant(100.0);
  286. tf_with(g.control_dependencies(new[] { a }), ctrl1 =>
  287. {
  288. b = array_ops.identity(a);
  289. });
  290. Assert.AreEqual(0, b.op.control_inputs.Length);
  291. }
  292. }
  293. }