You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ControlDependenciesTest.cs 14 kB


  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using Microsoft.VisualStudio.TestTools.UnitTesting;
  6. using Tensorflow;
  7. using Tensorflow.Eager;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Port of the control-dependency tests from tensorflow/python/framework/ops_test.py.
      /// Tests that are not (fully) ported yet keep the original Python source in a comment
      /// as a porting reference.
      /// </summary>
      [TestClass]
      public class ControlDependenciesTest : PythonTest
      {
          /// <summary>
          /// Ops created inside a <c>control_dependencies</c> scope get the scope's ops as
          /// control inputs, except when the dependency is already implied through a data input.
          /// </summary>
          [TestMethod]
          public void TestBasic()
          {
              var graph = tf.Graph().as_default();
              Tensor a = null, b = null, c = null, d = null, e = null;
              with<Graph>(graph, g =>
              {
                  a = constant_op.constant(1.0);
                  b = constant_op.constant(1.0);
                  with(g.control_dependencies(new[] { a }), x =>
                  {
                      c = constant_op.constant(1.0);
                      d = array_ops.identity(b);
                      e = array_ops.identity(c);
                  });
              });

              Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] { a.op }),
                  "c should have exactly a.op as control input");
              Assert.IsTrue(Enumerable.SequenceEqual(d.op.control_inputs, new[] { a.op }),
                  "d should have exactly a.op as control input");
              // e should be dominated by c: e reads c, and c already depends on a.
              Assert.AreEqual(0, e.op.control_inputs.Length);
          }

          /// <summary>
          /// Control dependencies given as a callable ("future").
          /// In the Python original, eager mode passes the callable itself and expects it to be
          /// invoked exactly once; graph mode calls it up front and passes the resulting tensor.
          /// </summary>
          [Ignore("Future is not supported yet")]
          [TestMethod]
          public void TestEager()
          {
              Tensor a = null, c = null;
              var calls = 0;
              Func<Tensor> future = () =>
              {
                  calls += 1;
                  return constant_op.constant(2.0);
              };

              using (var opts = new ContextOptions())
              using (var status = new Status())
              using (var context = new Context(opts, status))
              {
                  if (context.executing_eagerly())
                  {
                      // TODO: ops.control_dependencies needs to accept callables (here: the
                      // delegate `future`) and invoke them, like the Python original below.
                      a = constant_op.constant(1.0);
                      object b = future;
                      with(ops.control_dependencies(new object[] { a, b }), ctrl =>
                      {
                          c = constant_op.constant(3.0);
                      });
                      Assert.AreEqual(1, calls);
                  }
                  else
                  {
                      var graph = tf.Graph().as_default();
                      with<Graph>(graph, g =>
                      {
                          a = constant_op.constant(1.0);
                          // Graph mode depends on the future's *result* tensor, not on the delegate.
                          var b1 = future();
                          with(g.control_dependencies(new[] { a, b1 }), ctrl =>
                          {
                              c = constant_op.constant(3.0);
                          });
                          Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] { a.op, b1.op }),
                              "c should have a.op and b1.op (in that order) as control inputs");
                          Assert.AreEqual(1, calls);
                      });
                  }
              }

              /*
              def testEager(self):
                def future():
                  future.calls += 1
                  return constant_op.constant(2.0)
                future.calls = 0

                if context.executing_eagerly():
                  a = constant_op.constant(1.0)
                  b = future
                  with ops.control_dependencies([a, b]):
                    c = constant_op.constant(3.0)
                  self.assertEqual(future.calls, 1)
                else:
                  g = ops.Graph()
                  with g.as_default():
                    a = constant_op.constant(1.0)
                    b = future()
                    with g.control_dependencies([a, b]):
                      c = constant_op.constant(3.0)
                    self.assertEqual(c.op.control_inputs, [a.op, b.op])
                    self.assertEqual(future.calls, 1)
              */
          }

          /// <summary>
          /// Control dependencies given as an object that converts to a graph element.
          /// Not ported yet: there is no C# counterpart of Python's <c>_as_graph_element</c>
          /// protocol used by <c>ConvertibleObj</c>.
          /// </summary>
          [Ignore("How to port the ConvertibleObj?")]
          [TestMethod]
          public void TestBasicWithConversion()
          {
              var g = tf.Graph().as_default();
              // Note: Python's _apply_op can be replaced by g.create_op.
              // `a` is the op that ConvertibleObj should resolve to once it is ported.
              var a = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
              // TODO: ConvertibleObj, see original source below
              /*
              def testBasicWithConversion(self):
                g = ops.Graph()
                a = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                class ConvertibleObj(object):

                  def _as_graph_element(self):
                    return a

                with g.control_dependencies([ConvertibleObj()]):
                  c = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                self.assertEqual(c.op.control_inputs, [a.op])
              */
          }

          /// <summary>
          /// Nesting single-op control_dependencies scopes yields the same control inputs
          /// as one scope that lists all of those ops at once.
          /// </summary>
          [TestMethod]
          public void TestNested()
          {
              var g = tf.Graph().as_default();
              var a_1 = constant_op.constant(1.0);
              var a_2 = constant_op.constant(3.0);
              var a_3 = constant_op.constant(4.0);
              var a_4 = constant_op.constant(5.0);
              Tensor b_1 = null, b_2 = null;

              with(g.control_dependencies(new[] { a_1, a_2, a_3, a_4 }), ctrl =>
              {
                  b_1 = constant_op.constant(6.0);
              });

              with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
              {
                  with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
                  {
                      with(g.control_dependencies(new[] { a_3 }), ctrl3 =>
                      {
                          with(g.control_dependencies(new[] { a_4 }), ctrl4 =>
                          {
                              b_2 = constant_op.constant(7.0);
                          });
                      });
                  });
              });

              AssertItemsEqual(new[] { a_1.op, a_2.op, a_3.op, a_4.op }, b_1.op.control_inputs);
              AssertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs);
          }

          /// <summary>
          /// <c>control_dependencies(null)</c> clears all enclosing dependencies. Leaving a
          /// scope restores the dependencies that were active before it was entered.
          /// </summary>
          [Ignore("Fails")]
          [TestMethod]
          public void TestClear()
          {
              var g = tf.Graph().as_default();
              var a_1 = constant_op.constant(1.0);
              var a_2 = constant_op.constant(3.0);
              var a_3 = constant_op.constant(4.0);
              var a_4 = constant_op.constant(5.0);
              Tensor b_3_4 = null, b_3 = null, b_none = null, b_1 = null, b_1_2 = null, b_none2 = null;

              with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
              {
                  with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
                  {
                      with(g.control_dependencies(null), ctrl3 =>
                      {
                          with(g.control_dependencies(new[] { a_3 }), ctrl4 =>
                          {
                              with(g.control_dependencies(new[] { a_4 }), ctrl5 =>
                              {
                                  // deps [a_3, a_4]
                                  b_3_4 = constant_op.constant(7.0);
                              });
                              // deps = [a_3]
                              b_3 = constant_op.constant(8.0);
                          });
                          // deps back to None
                          b_none = constant_op.constant(9.0);
                      });
                      // deps back to [a_1, a_2]
                      b_1_2 = constant_op.constant(10.0);
                  });
                  // deps back to [a_1]
                  b_1 = constant_op.constant(11.0);
                  with(g.control_dependencies(null), ctrl6 =>
                  {
                      // deps are None again
                      b_none2 = constant_op.constant(12.0);
                  });
              });

              AssertItemsEqual(new[] { a_3.op, a_4.op }, b_3_4.op.control_inputs);
              AssertItemsEqual(new[] { a_3.op }, b_3.op.control_inputs);
              AssertItemsEqual(new object[0], b_none.op.control_inputs);
              AssertItemsEqual(new[] { a_1.op, a_2.op }, b_1_2.op.control_inputs);
              AssertItemsEqual(new[] { a_1.op }, b_1.op.control_inputs);
              AssertItemsEqual(new object[0], b_none2.op.control_inputs);
              /*
              def testClear(self):
                g = ops.Graph()
                a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                with g.control_dependencies([a_1]):
                  with g.control_dependencies([a_2]):
                    with g.control_dependencies(None):
                      with g.control_dependencies([a_3]):
                        with g.control_dependencies([a_4]):
                          # deps [a_3, a_4]
                          b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                        # deps = [a_3]
                        b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                      # deps back to None
                      b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                    # deps back to [a_1, a_2]
                    b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  # deps back to [a_1]
                  b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  with g.control_dependencies(None):
                    # deps are None again
                    b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
                self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
                self.assertItemsEqual([], b_none.op.control_inputs)
                self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
                self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
                self.assertItemsEqual([], b_none2.op.control_inputs)
              */
          }

          /// <summary>
          /// Mixes nested control-dependency scopes with data dependencies.
          /// Not ported yet; the original Python is kept below.
          /// </summary>
          [Ignore("will fail due to unsupported op 'FloatOutput'")]
          [TestMethod]
          public void TestComplex()
          {
              /*
              def testComplex(self):
                g = ops.Graph()

                # Usage pattern:
                # * Nodes a_i are constants defined at the outermost scope, and are used
                #   as control inputs for the ith nested scope.
                # * Nodes b_i are defined as Mul(a_3, a_4) at each scope.
                # * Nodes c_i are defined as Mul(a_1, b_1) at each scope.
                # * Nodes d_i are defined as Mul(b_i, c_i) at each scope.
                # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1.

                a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                with g.control_dependencies([a_1]):
                  b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                                  [dtypes.float32])
                  c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                                  [dtypes.float32])
                  d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1],
                                  [dtypes.float32])
                  e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  with g.control_dependencies([a_2]):
                    b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                                    [dtypes.float32])
                    c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                                    [dtypes.float32])
                    d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2],
                                    [dtypes.float32])
                    e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1],
                                    [dtypes.float32])
                    with g.control_dependencies([a_3]):
                      b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                                      [dtypes.float32])
                      c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                                      [dtypes.float32])
                      d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3],
                                      [dtypes.float32])
                      e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2],
                                      [dtypes.float32])
                      with g.control_dependencies([a_4]):
                        b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                                        [dtypes.float32])
                        c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                                        [dtypes.float32])
                        d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4],
                                        [dtypes.float32])
                        e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3],
                                        [dtypes.float32])

                self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
                self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs)
                self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs)
                self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs)

                self.assertItemsEqual([], c_1.op.control_inputs)
                self.assertItemsEqual([a_2.op], c_2.op.control_inputs)
                self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs)
                self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs)

                self.assertItemsEqual([], d_1.op.control_inputs)
                self.assertItemsEqual([], d_2.op.control_inputs)
                self.assertItemsEqual([], d_3.op.control_inputs)
                self.assertItemsEqual([], d_4.op.control_inputs)

                self.assertItemsEqual([a_1.op], e_1.op.control_inputs)
                self.assertItemsEqual([a_2.op], e_2.op.control_inputs)
                self.assertItemsEqual([a_3.op], e_3.op.control_inputs)
                self.assertItemsEqual([a_4.op], e_4.op.control_inputs)
              */
          }

          /// <summary>
          /// Depending on two outputs of the same op must record that op only once.
          /// Not ported yet. The comment below also contains the Python test
          /// testNoControlDependencyWithDataDependency, which has no C# counterpart yet.
          /// </summary>
          [Ignore("will fail due to unsupported op 'FloatOutput'")]
          [TestMethod]
          public void TestRepeatedDependency()
          {
              /*
              def testRepeatedDependency(self):
                g = ops.Graph()
                a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
                a_0, a_1 = a.outputs
                with g.control_dependencies([a_0]):
                  b = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  with g.control_dependencies([a_1]):
                    c = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                self.assertEqual(b.op.control_inputs, [a])
                self.assertEqual(c.op.control_inputs, [a])

              def testNoControlDependencyWithDataDependency(self):
                g = ops.Graph()
                a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                with g.control_dependencies([a]):
                  b = _apply_op(g, "Identity", [a], [dtypes.float32])

                self.assertEqual(b.op.control_inputs, [])
              */
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。