You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ControlDependenciesTest.cs 14 kB

6 years ago
Lines 1–323
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using Microsoft.VisualStudio.TestTools.UnitTesting;
  6. using Tensorflow;
  7. using Tensorflow.Eager;
  8. using static Tensorflow.Python;
  9. namespace TensorFlowNET.UnitTest.ops_test
  10. {
  11. /// <summary>
  12. /// excerpt of tensorflow/python/framework/ops_test.py
  13. /// </summary>
  14. [TestClass]
  15. public class ControlDependenciesTest : PythonTest
  16. {
  17. [TestMethod]
  18. public void TestBasic()
  19. {
  20. var graph = tf.Graph().as_default();
  21. Tensor a = null, b = null, c = null, d = null, e = null;
  22. with<Graph>(graph, g =>
  23. {
  24. a = constant_op.constant(1.0);
  25. b = constant_op.constant(1.0);
  26. with(g.control_dependencies(new[] { a }), x =>
  27. {
  28. c = constant_op.constant(1.0);
  29. d = array_ops.identity(b);
  30. e = array_ops.identity(c);
  31. });
  32. });
  33. Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] { a.op }));
  34. Assert.IsTrue(Enumerable.SequenceEqual(d.op.control_inputs, new[] { a.op }));
  35. // e should be dominated by c.
  36. Assert.AreEqual(0, e.op.control_inputs.Length);
  37. }
  38. [Ignore("Future is not supported yet")]
  39. [TestMethod]
  40. public void TestEager()
  41. {
  42. Tensor a = null, c = null, d = null, e = null;
  43. object b = null;
  44. var calls = 0;
  45. Func<Tensor> future = () =>
  46. {
  47. calls += 1;
  48. return constant_op.constant(2.0);
  49. };
  50. using (var opts = new ContextOptions())
  51. using (var status = new Status())
  52. using (var context = new Context(opts, status))
  53. {
  54. if (context.executing_eagerly())
  55. {
  56. // TODO: make this compile (see original Python code below)
  57. a = constant_op.constant(1.0);
  58. b = future; // <--- {henon} obviously, this doesn't compile, looks like control_dependencies needs to be able to take callables as well.
  59. with(ops.control_dependencies(new object[] { a, b }), ctrl =>
  60. {
  61. return c = constant_op.constant(3.0);
  62. });
  63. Assert.AreEqual(calls, 1);
  64. }
  65. else
  66. {
  67. var graph = tf.Graph().as_default();
  68. with<Graph>(graph, g =>
  69. {
  70. a = constant_op.constant(1.0);
  71. var b1 = future();
  72. with(g.control_dependencies(new[] { a, b }), ctrl =>
  73. {
  74. c = constant_op.constant(3.0);
  75. });
  76. Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] { a.op, b1.op }));
  77. Assert.AreEqual(1, calls);
  78. });
  79. }
  80. }
  81. /*
  82. def testEager(self):
  83. def future():
  84. future.calls += 1
  85. return constant_op.constant(2.0)
  86. future.calls = 0
  87. if context.executing_eagerly():
  88. a = constant_op.constant(1.0)
  89. b = future
  90. with ops.control_dependencies([a, b]):
  91. c = constant_op.constant(3.0)
  92. self.assertEqual(future.calls, 1)
  93. else:
  94. g = ops.Graph()
  95. with g.as_default():
  96. a = constant_op.constant(1.0)
  97. b = future()
  98. with g.control_dependencies([a, b]):
  99. c = constant_op.constant(3.0)
  100. self.assertEqual(c.op.control_inputs, [a.op, b.op])
  101. self.assertEqual(future.calls, 1)
  102. */
  103. }
  104. [Ignore("How to port the ConvertibleObj?")]
  105. [TestMethod]
  106. public void TestBasicWithConversion()
  107. {
  108. var g = tf.Graph().as_default();
  109. // Note: _apply_op can be replaced by g.create_op
  110. var a = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
  111. // TODO: ConvertibleObj, see original source below
  112. /*
  113. def testBasicWithConversion(self):
  114. g = ops.Graph()
  115. a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  116. class ConvertibleObj(object):
  117. def _as_graph_element(self):
  118. return a
  119. with g.control_dependencies([ConvertibleObj()]):
  120. c = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  121. self.assertEqual(c.op.control_inputs, [a.op])
  122. */
  123. }
  124. [TestMethod]
  125. public void TestNested()
  126. {
  127. var g = tf.Graph().as_default();
  128. var a_1 = constant_op.constant(1.0);
  129. var a_2 = constant_op.constant(3.0);
  130. var a_3 = constant_op.constant(4.0);
  131. var a_4 = constant_op.constant(5.0);
  132. Tensor b_1 = null, b_2 = null;
  133. with(g.control_dependencies(new[] { a_1, a_2, a_3, a_4 }), ctrl =>
  134. {
  135. b_1 = constant_op.constant(6.0);
  136. });
  137. with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
  138. {
  139. with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
  140. {
  141. with(g.control_dependencies(new[] { a_3 }), ctrl3 =>
  142. {
  143. with(g.control_dependencies(new[] { a_4 }), ctrl4 =>
  144. {
  145. b_2 = constant_op.constant(7.0);
  146. });
  147. });
  148. });
  149. });
  150. //var z=tf.add(a_1, tf.multiply(b_2, b_1));
  151. //with(g.control_dependencies(new[] {z}), ctrl =>
  152. //{
  153. // var z1 = tf.add(a_3, tf.multiply(a_4, a_2));
  154. //});
  155. //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
  156. assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op });
  157. assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs);
  158. }
  159. [TestMethod]
  160. public void TestClear()
  161. {
  162. var g = tf.Graph().as_default();
  163. var a_1 = constant_op.constant(1.0);
  164. var a_2 = constant_op.constant(3.0);
  165. var a_3 = constant_op.constant(4.0);
  166. var a_4 = constant_op.constant(5.0);
  167. Operation b_3_4 = null, b_3 = null, b_none = null, b_1 = null, b_1_2 = null, b_none2 = null;
  168. with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
  169. {
  170. with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
  171. {
  172. with(g.control_dependencies(null), ctrl3 =>
  173. {
  174. with(g.control_dependencies(new[] { a_3 }), ctrl4 =>
  175. {
  176. with(g.control_dependencies(new[] { a_4 }), ctrl5 =>
  177. {
  178. // deps [a_3, a_4]
  179. b_3_4 = constant_op.constant(7.0);
  180. });
  181. // deps = [a_3]
  182. b_3 = constant_op.constant(8.0);
  183. });
  184. // deps back to None
  185. b_none = constant_op.constant(9.0);
  186. });
  187. // deps back to [a_1, a_2]
  188. b_1_2 = constant_op.constant(10.0);
  189. });
  190. // deps back to [a_1]
  191. b_1 = constant_op.constant(11.0);
  192. with(g.control_dependencies(null), ctrl6 =>
  193. {
  194. // deps are None again
  195. b_none2 = constant_op.constant(12.0);
  196. });
  197. });
  198. // Note assertItemsEqual(given, expected), expected and given parameters should be swapped below
  199. assertItemsEqual(new[] { a_3.op, a_4.op }, b_3_4.op.control_inputs);
  200. assertItemsEqual(new[] { a_3.op }, b_3.op.control_inputs);
  201. assertItemsEqual(new object[0], b_none.op.control_inputs);
  202. assertItemsEqual(new[] { a_1.op, a_2.op }, b_1_2.op.control_inputs);
  203. assertItemsEqual(new[] { a_1.op }, b_1.op.control_inputs);
  204. assertItemsEqual(new object[0], b_none2.op.control_inputs);
  205. }
  206. [TestMethod]
  207. public void TestComplex()
  208. {
  209. var g = tf.Graph().as_default();
  210. // Usage pattern:
  211. // * Nodes a_i are constants defined at the outermost scope, and are used
  212. // as control inputs for the ith nested scope.
  213. // * Nodes b_i are defined as Mul(a_3, a_4) at each scope.
  214. // * Nodes c_i are defined as Mul(a_1, b_1) at each scope.
  215. // * Nodes d_i are defined as Mul(b_i, c_i) at each scope.
  216. // * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1.
  217. var a_1 = constant_op.constant(1.0);
  218. var a_2 = constant_op.constant(2.0);
  219. var a_3 = constant_op.constant(3.0);
  220. var a_4 = constant_op.constant(4.0);
  221. Operation b_1 = null, b_2 = null, b_3 = null, b_4 = null;
  222. Operation c_1 = null, c_2 = null, c_3 = null, c_4 = null;
  223. Operation d_1 = null, d_2 = null, d_3 = null, d_4 = null;
  224. Operation e_1 = null, e_2 = null, e_3 = null, e_4 = null;
  225. with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
  226. {
  227. b_1 = tf.multiply(a_3, a_4);
  228. c_1 = tf.multiply(a_1, b_1.output);
  229. d_1 = tf.multiply(b_1.output, c_1.output);
  230. e_1 = constant_op.constant(5.0);
  231. with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
  232. {
  233. b_2 = tf.multiply(a_3, a_4);
  234. c_2 = tf.multiply(a_1, b_1.output);
  235. d_2 = tf.multiply(b_2.output, c_2.output);
  236. e_2 = tf.multiply(e_1.output, e_1.output);
  237. with(g.control_dependencies(new[] { a_3 }), ctrl3 =>
  238. {
  239. b_3 = tf.multiply(a_3, a_4);
  240. c_3 = tf.multiply(a_1, b_1.output);
  241. d_3 = tf.multiply(b_3.output, c_3.output);
  242. e_3 = tf.multiply(e_2.output, e_2.output);
  243. with(g.control_dependencies(new[] { a_4 }), ctrl4 =>
  244. {
  245. b_4 = tf.multiply(a_3, a_4);
  246. c_4 = tf.multiply(a_1, b_1.output);
  247. d_4 = tf.multiply(b_4.output, c_4.output);
  248. e_4 = tf.multiply(e_3.output, e_3.output);
  249. });
  250. });
  251. });
  252. });
  253. // Note assertItemsEqual(given, expected), expected and given parameters should be swapped below
  254. assertItemsEqual(new[] {a_1.op}, b_1.op.control_inputs);
  255. assertItemsEqual(new[] {a_1.op, a_2.op}, b_2.op.control_inputs);
  256. assertItemsEqual(new[] { a_1.op, a_2.op}, b_3.op.control_inputs);
  257. assertItemsEqual(new[] {a_1.op, a_2.op}, b_4.op.control_inputs);
  258. assertItemsEqual(new object[0], c_1.op.control_inputs);
  259. assertItemsEqual(new[] {a_2.op}, c_2.op.control_inputs);
  260. assertItemsEqual(new[] {a_2.op, a_3.op}, c_3.op.control_inputs);
  261. assertItemsEqual(new[] {a_2.op, a_3.op, a_4.op}, c_4.op.control_inputs);
  262. assertItemsEqual(new object[0], d_1.op.control_inputs);
  263. assertItemsEqual(new object[0], d_2.op.control_inputs);
  264. assertItemsEqual(new object[0], d_3.op.control_inputs);
  265. assertItemsEqual(new object[0], d_4.op.control_inputs);
  266. assertItemsEqual(new[] {a_1.op}, e_1.op.control_inputs);
  267. assertItemsEqual(new[] {a_2.op}, e_2.op.control_inputs);
  268. assertItemsEqual(new[] {a_3.op}, e_3.op.control_inputs);
  269. assertItemsEqual(new[] {a_4.op}, e_4.op.control_inputs);
  270. }
  271. [Ignore("Don't know how to create an operation with two outputs")]
  272. [TestMethod]
  273. public void TestRepeatedDependency()
  274. {
  275. /*
  276. def testRepeatedDependency(self):
  277. g = ops.Graph()
  278. a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
  279. a_0, a_1 = a.outputs
  280. with g.control_dependencies([a_0]):
  281. b = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  282. with g.control_dependencies([a_1]):
  283. c = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  284. self.assertEqual(b.op.control_inputs, [a])
  285. self.assertEqual(c.op.control_inputs, [a])
  286. */
  287. }
  288. [TestMethod]
  289. public void TestNoControlDependencyWithDataDependency()
  290. {
  291. var g = tf.Graph().as_default();
  292. Operation b = null;
  293. var a = constant_op.constant(100.0);
  294. with(g.control_dependencies(new[] { a }), ctrl1 =>
  295. {
  296. b = array_ops.identity(a);
  297. });
  298. Assert.AreEqual(0, b.op.control_inputs.Length);
  299. }
  300. }
  301. }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。