You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ControlDependenciesTest.cs 14 kB

Line numbers 1–346
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using Microsoft.VisualStudio.TestTools.UnitTesting;
  6. using Tensorflow;
  7. using Tensorflow.Eager;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Port of the control-dependency tests from tensorflow/python/framework/ops_test.py.
      /// The original Python of each test is quoted in a block comment next to its C# port.
      /// NOTE(review): the Python indentation in those quotes was lost when this file was
      /// extracted and has been reconstructed; verify against the upstream ops_test.py.
      /// </summary>
      [TestClass]
      public class ControlDependenciesTest : PythonTest
      {
          /// <summary>
          /// Ops created inside <c>g.control_dependencies([a])</c> get <c>a</c> as a control
          /// input, except when that dependency is already implied through a data input
          /// (<c>e</c> reads <c>c</c>, which already depends on <c>a</c>).
          /// </summary>
          [TestMethod]
          public void TestBasic()
          {
              var graph = tf.Graph().as_default();
              Tensor a = null, b = null, c = null, d = null, e = null;
              with<Graph>(graph, g =>
              {
                  a = constant_op.constant(1.0);
                  b = constant_op.constant(1.0);
                  with(g.control_dependencies(new ITensorOrOperation[] { a }), x =>
                  {
                      c = constant_op.constant(1.0);
                      d = array_ops.identity(b);
                      e = array_ops.identity(c);
                  });
              });

              // Ordered comparison (mirrors Python's assertEqual on lists), with a useful
              // failure message instead of a bare IsTrue.
              CollectionAssert.AreEqual(new[] { a.op }, c.op.control_inputs);
              CollectionAssert.AreEqual(new[] { a.op }, d.op.control_inputs);
              // e should be dominated by c.
              Assert.AreEqual(0, e.op.control_inputs.Length);
          }

          /// <summary>
          /// In graph mode, checks that both control inputs are recorded in order and that the
          /// function producing one of them is invoked exactly once. The eager-mode branch is
          /// not ported yet (see the TODO inside).
          /// </summary>
          [Ignore("Part of this test is not compiling")]
          [TestMethod]
          public void TestEager()
          {
              Tensor a = null, b = null, c = null;
              var calls = 0;
              // Counts invocations so the test can assert it was evaluated exactly once.
              Func<Tensor> future = () =>
              {
                  calls += 1;
                  return constant_op.constant(2.0);
              };
              using (var opts = new ContextOptions())
              using (var status = new Status())
              using (var context = new Context(opts, status))
              {
                  if (context.executing_eagerly())
                  {
                      // TODO: make this compile (see original Python code below).
                      //a = constant_op.constant(1.0);
                      //b = future; // <--- {henon} obviously, this doesn't compile, looks like control_dependencies needs to be able to take callables as well.
                      //with(ops.control_dependencies(new Operation[] {a, b}), ctrl =>
                      //{
                      //    return c = constant_op.constant(3.0);
                      //});
                      //Assert.AreEqual(calls, 1);
                  }
                  else
                  {
                      var graph = tf.Graph();
                      with<Graph>(graph.as_default(), g =>
                      {
                          a = constant_op.constant(1.0);
                          b = future();
                          with(g.control_dependencies(new ITensorOrOperation[] { a, b }), ctrl =>
                          {
                              c = constant_op.constant(3.0);
                          });
                          CollectionAssert.AreEqual(new[] { a.op, b.op }, c.op.control_inputs);
                          Assert.AreEqual(1, calls);
                      });
                  }
              }
          }
          /*
          def testEager(self):
            def future():
              future.calls += 1
              return constant_op.constant(2.0)
            future.calls = 0

            if context.executing_eagerly():
              a = constant_op.constant(1.0)
              b = future
              with ops.control_dependencies([a, b]):
                c = constant_op.constant(3.0)
              self.assertEqual(future.calls, 1)
            else:
              g = ops.Graph()
              with g.as_default():
                a = constant_op.constant(1.0)
                b = future()
                with g.control_dependencies([a, b]):
                  c = constant_op.constant(3.0)
                self.assertEqual(c.op.control_inputs, [a.op, b.op])
                self.assertEqual(future.calls, 1)
          */

          // Note: {henon}, all tests below use the Python helper _apply_op, which is not really
          // portable to C# (see original source below), but it can be replaced by g.create_op(...).
          // Difference: for single-output ops _apply_op returns op.outputs[0] (a tensor), whereas
          // create_op returns the Operation itself.
          /*
          def _apply_op(g, *args, **kwargs):
            op = g.create_op(*args, **kwargs)
            if len(op.outputs) == 1:
              return op.outputs[0]
            else:
              return op.outputs
          */

          /// <summary>
          /// Should check that an object convertible to a graph element can be used as a
          /// control input. Only the first step is ported so far.
          /// </summary>
          [Ignore("Not ported yet: needs a ConvertibleObj (_as_graph_element) equivalent, and 'FloatOutput' is not a registered op")]
          [TestMethod]
          public void TestBasicWithConversion()
          {
              // Fresh graph, as in the Python original; do not pollute the shared default graph.
              var g = tf.Graph();
              // Note: _apply_op can be replaced by g.create_op
              var a = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
              // TODO: port ConvertibleObj, use it as the control input, then assert
              // c.op.control_inputs == [a.op] (see original source below).
              /*
              def testBasicWithConversion(self):
                g = ops.Graph()
                a = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                class ConvertibleObj(object):

                  def _as_graph_element(self):
                    return a

                with g.control_dependencies([ConvertibleObj()]):
                  c = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                self.assertEqual(c.op.control_inputs, [a.op])
              */
          }

          /// <summary>
          /// Nesting four single-input control_dependencies scopes must yield the same (unordered)
          /// control inputs as one scope listing all four inputs.
          /// </summary>
          [Ignore("Fails with message: Op type not registered 'FloatOutput' in binary running on ...")]
          [TestMethod]
          public void TestNested()
          {
              // Fresh graph, as in the Python original (g = ops.Graph()); building on the shared
              // default graph would leak these ops into other tests.
              var g = tf.Graph();
              var a_1 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
              var a_2 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
              var a_3 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
              var a_4 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
              Operation b_1 = null, b_2 = null;
              with(g.control_dependencies(new ITensorOrOperation[] { a_1, a_2, a_3, a_4 }), ctrl =>
              {
                  b_1 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
              });
              with(g.control_dependencies(new ITensorOrOperation[] { a_1 }), ctrl1 =>
              {
                  with(g.control_dependencies(new ITensorOrOperation[] { a_2 }), ctrl2 =>
                  {
                      with(g.control_dependencies(new ITensorOrOperation[] { a_3 }), ctrl3 =>
                      {
                          with(g.control_dependencies(new ITensorOrOperation[] { a_4 }), ctrl4 =>
                          {
                              b_2 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
                          });
                      });
                  });
              });
              AssertItemsEqual(new[] { a_1.op, a_2.op, a_3.op, a_4.op }, b_1.op.control_inputs);
              AssertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs);
              /*
              def testNested(self):
                g = ops.Graph()
                a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                with g.control_dependencies([a_1, a_2, a_3, a_4]):
                  b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                with g.control_dependencies([a_1]):
                  with g.control_dependencies([a_2]):
                    with g.control_dependencies([a_3]):
                      with g.control_dependencies([a_4]):
                        b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op],
                                      b_1.op.control_inputs)
                self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs)
              */
          }

          /// <summary>
          /// Not ported yet. Per the quoted Python, control_dependencies(None) clears the
          /// inherited dependencies inside its scope, and leaving each scope restores the
          /// outer set.
          /// </summary>
          [Ignore("will fail due to unsupported op 'FloatOutput'")]
          [TestMethod]
          public void TestClear()
          {
              /*
              def testClear(self):
                g = ops.Graph()
                a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                with g.control_dependencies([a_1]):
                  with g.control_dependencies([a_2]):
                    with g.control_dependencies(None):
                      with g.control_dependencies([a_3]):
                        with g.control_dependencies([a_4]):
                          # deps [a_3, a_4]
                          b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                        # deps = [a_3]
                        b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                      # deps back to None
                      b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                    # deps back to [a_1, a_2]
                    b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  # deps back to [a_1]
                  b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  with g.control_dependencies(None):
                    # deps are None again
                    b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
                self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
                self.assertItemsEqual([], b_none.op.control_inputs)
                self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
                self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
                self.assertItemsEqual([], b_none2.op.control_inputs)
              */
          }

          /// <summary>
          /// Not ported yet. Per the quoted Python, this mixes nested control scopes with data
          /// inputs and checks which control inputs are kept or pruned.
          /// </summary>
          [Ignore("will fail due to unsupported op 'FloatOutput'")]
          [TestMethod]
          public void TestComplex()
          {
              /*
              def testComplex(self):
                g = ops.Graph()

                # Usage pattern:
                # * Nodes a_i are constants defined at the outermost scope, and are used
                #   as control inputs for the ith nested scope.
                # * Nodes b_i are defined as Mul(a_3, a_4) at each scope.
                # * Nodes c_i are defined as Mul(a_1, b_1) at each scope.
                # * Nodes d_i are defined as Mul(b_i, c_i) at each scope.
                # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1.

                a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                with g.control_dependencies([a_1]):
                  b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                                  [dtypes.float32])
                  c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                                  [dtypes.float32])
                  d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1],
                                  [dtypes.float32])
                  e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  with g.control_dependencies([a_2]):
                    b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                                    [dtypes.float32])
                    c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                                    [dtypes.float32])
                    d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2],
                                    [dtypes.float32])
                    e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1],
                                    [dtypes.float32])
                    with g.control_dependencies([a_3]):
                      b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                                      [dtypes.float32])
                      c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                                      [dtypes.float32])
                      d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3],
                                      [dtypes.float32])
                      e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2],
                                      [dtypes.float32])
                      with g.control_dependencies([a_4]):
                        b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                                        [dtypes.float32])
                        c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                                        [dtypes.float32])
                        d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4],
                                        [dtypes.float32])
                        e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3],
                                        [dtypes.float32])

                self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
                self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs)
                self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs)
                self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs)

                self.assertItemsEqual([], c_1.op.control_inputs)
                self.assertItemsEqual([a_2.op], c_2.op.control_inputs)
                self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs)
                self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs)

                self.assertItemsEqual([], d_1.op.control_inputs)
                self.assertItemsEqual([], d_2.op.control_inputs)
                self.assertItemsEqual([], d_3.op.control_inputs)
                self.assertItemsEqual([], d_4.op.control_inputs)

                self.assertItemsEqual([a_1.op], e_1.op.control_inputs)
                self.assertItemsEqual([a_2.op], e_2.op.control_inputs)
                self.assertItemsEqual([a_3.op], e_3.op.control_inputs)
                self.assertItemsEqual([a_4.op], e_4.op.control_inputs)
              */
          }

          /// <summary>
          /// Not ported yet. Per the quoted Python, depending on two outputs of the same op
          /// must record that op only once as a control input.
          /// </summary>
          [Ignore("will fail due to unsupported op 'FloatOutput'")]
          [TestMethod]
          public void TestRepeatedDependency()
          {
              /*
              def testRepeatedDependency(self):
                g = ops.Graph()
                a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
                a_0, a_1 = a.outputs
                with g.control_dependencies([a_0]):
                  b = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  with g.control_dependencies([a_1]):
                    c = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                self.assertEqual(b.op.control_inputs, [a])
                self.assertEqual(c.op.control_inputs, [a])
              */
          }

          /// <summary>
          /// Not ported yet. Per the quoted Python, no control input is added for an op the new
          /// op already consumes as a data input.
          /// </summary>
          [Ignore("will fail due to unsupported op 'FloatOutput'")]
          [TestMethod]
          public void TestNoControlDependencyWithDataDependency()
          {
              /*
              def testNoControlDependencyWithDataDependency(self):
                g = ops.Graph()
                a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                with g.control_dependencies([a]):
                  b = _apply_op(g, "Identity", [a], [dtypes.float32])

                self.assertEqual(b.op.control_inputs, [])
              */
          }
      }
  }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。