You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

SwitchTestCase.cs 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. using System;
  2. using Microsoft.VisualStudio.TestTools.UnitTesting;
  3. using Tensorflow;
namespace TensorFlowNET.UnitTest.control_flow_ops_test
{
    /// <summary>
    /// Partial C# port of an excerpt of
    /// tensorflow/python/framework/ops/control_flow_ops_test.py.
    ///
    /// Every test here is currently marked <c>[Ignore("TODO")]</c>: the bodies
    /// are placeholders that preserve the original Python source in comments
    /// until the required TensorFlow.NET APIs (while_loop with tuple-returning
    /// bodies, TensorArray, gradients over IndexedSlices) are available.
    /// </summary>
    [TestClass]
    public class SwitchTestCase : PythonTest
    {
        /// <summary>
        /// Intended to verify that reading a resource variable inside a
        /// while_loop produces the expected accumulated cost (10.0 after
        /// five iterations). Only the variable and the loop condition are
        /// ported; the loop body does not yet compile (see TODO below).
        /// </summary>
        [Ignore("TODO")]
        [TestMethod]
        public void testResourceReadInLoop()
        {
            // NOTE(review): currently unused until the commented-out loop below
            // is ported — kept so the ported code can reference it directly.
            var embedding_matrix = variable_scope.get_variable(
                "embedding_matrix", initializer: new double[,] { { 2.0 }, { 3.0 } }, use_resource: true);

            // Loop condition: iterate while the counter is below 5.
            // NOTE(review): also unused until while_loop is wired up below.
            Tensor cond(Tensor it, Tensor _)
            {
                return it < 5;
            }

            // TODO: below code doesn't compile
            //(Tensor, Tensor) body(Tensor it, Tensor cost)
            //{
            //    var embedding = embedding_ops.embedding_lookup(embedding_matrix, new int[]{0});
            //    cost += math_ops.reduce_sum(embedding);
            //    return (it + 1, cost);
            //}
            //var (_, cost1) = control_flow_ops.while_loop(
            //    cond, body, new[]
            //    {
            //        constant_op.constant(0),
            //        constant_op.constant(0.0)
            //    });
            //with<Session>(this.cached_session(), sess =>
            //{
            //    self.evaluate(variables.global_variables_initializer());
            //    self.assertAllEqual(10.0, self.evaluate(cost1));
            //});
        }

        /// <summary>
        /// Gradient of an embedding lookup through cond-in-while_loop,
        /// using a non-resource variable. Delegates to the shared helper.
        /// </summary>
        [Ignore("TODO")]
        [TestMethod]
        public void testIndexedSlicesGradientInCondInWhileLoop()
        {
            doTestIndexedSlicesGradientInCondInWhileLoop(use_resource: false);
        }

        /// <summary>
        /// Same scenario as above but with a resource variable.
        /// </summary>
        [Ignore("TODO")]
        [TestMethod]
        public void testIndexedSlicesGradientInCondInWhileLoopResource()
        {
            doTestIndexedSlicesGradientInCondInWhileLoop(use_resource: true);
        }

        /// <summary>
        /// Shared body for the two tests above: compares dynamic gradients
        /// computed through a while_loop/cond against statically-unrolled
        /// gradients of the same computation. Not yet ported — the original
        /// Python implementation is preserved below as the porting reference.
        /// </summary>
        /// <param name="use_resource">Whether to create the embedding matrix as a resource variable.</param>
        private void doTestIndexedSlicesGradientInCondInWhileLoop(bool use_resource = false)
        {
            //def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False):
            //    embedding_matrix = variable_scope.get_variable(
            //        "embedding_matrix", [5, 5],
            //        initializer=init_ops.random_normal_initializer(),
            //        use_resource=use_resource)
            //    def cond(it, _):
            //        return it < 5
            //    def body(it, cost):
            //        embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
            //        cost = control_flow_ops.cond(
            //            math_ops.equal(it, 3), lambda: math_ops.square(cost),
            //            (lambda: cost + math_ops.reduce_sum(embedding)))
            //        return it + 1, cost
            //    _, cost = control_flow_ops.while_loop(
            //        cond, body, [constant_op.constant(0),
            //                     constant_op.constant(0.0)])
            //    dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0]
            //    dynamic_grads = math_ops.segment_sum(dynamic_grads.values,
            //                                         dynamic_grads.indices)
            //    embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
            //    static = math_ops.square(
            //        math_ops.reduce_sum(embedding) + math_ops.reduce_sum(embedding) +
            //        math_ops.reduce_sum(embedding)) + math_ops.reduce_sum(embedding)
            //    static_grads = gradients_impl.gradients(static, [embedding_matrix])[0]
            //    static_grads = math_ops.segment_sum(static_grads.values,
            //                                        static_grads.indices)
            //    with self.cached_session():
            //        self.evaluate(variables.global_variables_initializer())
            //        self.assertAllEqual(*self.evaluate([static_grads, dynamic_grads]))
        }

        /// <summary>
        /// Gradient of a TensorArray-backed while_loop over a fixed-shape
        /// placeholder. Not yet ported — original Python preserved below.
        /// Expects output sum 20 and a gradient of all-ones w.r.t. inputs.
        /// </summary>
        [Ignore("TODO")]
        [TestMethod]
        public void testIndexedSlicesWithShapeGradientInWhileLoop()
        {
            //@test_util.run_v1_only("b/120545219")
            //def testIndexedSlicesWithShapeGradientInWhileLoop(self):
            //    for dtype in [dtypes.float32, dtypes.float64]:
            //        with self.cached_session() as sess:
            //            num_steps = 9
            //            inputs = array_ops.placeholder(dtype=dtype, shape=[num_steps])
            //            initial_outputs = tensor_array_ops.TensorArray(
            //                dtype=dtype, size=num_steps)
            //            initial_i = constant_op.constant(0, dtype=dtypes.int32)
            //            def cond(i, _):
            //                return i < num_steps  # pylint: disable=cell-var-from-loop
            //            def body(i, outputs):
            //                x = array_ops.gather(inputs, i)  # pylint: disable=cell-var-from-loop
            //                outputs = outputs.write(i, x)
            //                return i + 1, outputs
            //            _, outputs = control_flow_ops.while_loop(cond, body,
            //                                                     [initial_i, initial_outputs])
            //            outputs = math_ops.reduce_sum(outputs.stack())
            //            r = gradients_impl.gradients([outputs], [inputs])[0]
            //            grad_wr_inputs = ops.convert_to_tensor(r)
            //            o, grad = sess.run([outputs, grad_wr_inputs],
            //                               feed_dict={inputs: [4, 6, 0, 7, 0, 0, 1, 2, 0]})
            //            self.assertEquals(o, 20)
            //            self.assertAllEqual(grad, [1] * num_steps)
        }

        /// <summary>
        /// Same as the previous test but with a dynamically-sized TensorArray
        /// and an unshaped placeholder. Not yet ported — original Python
        /// preserved below. Expects output sum 6 and an all-ones gradient.
        /// </summary>
        [Ignore("TODO")]
        [TestMethod]
        public void testIndexedSlicesWithDynamicShapeGradientInWhileLoop()
        {
            //@test_util.run_v1_only("b/120545219")
            //def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self):
            //    for dtype in [dtypes.float32, dtypes.float64]:
            //        with self.cached_session() as sess:
            //            inputs = array_ops.placeholder(dtype=dtype)
            //            initial_outputs = tensor_array_ops.TensorArray(
            //                dtype=dtype, dynamic_size=True, size=1)
            //            initial_i = constant_op.constant(0, dtype=dtypes.int32)
            //            def cond(i, _):
            //                return i < array_ops.size(inputs)  # pylint: disable=cell-var-from-loop
            //            def body(i, outputs):
            //                x = array_ops.gather(inputs, i)  # pylint: disable=cell-var-from-loop
            //                outputs = outputs.write(i, x)
            //                return i + 1, outputs
            //            _, outputs = control_flow_ops.while_loop(cond, body,
            //                                                     [initial_i, initial_outputs])
            //            outputs = math_ops.reduce_sum(outputs.stack())
            //            r = gradients_impl.gradients([outputs], [inputs])[0]
            //            grad_wr_inputs = ops.convert_to_tensor(r)
            //            o, grad = sess.run([outputs, grad_wr_inputs],
            //                               feed_dict={inputs: [1, 3, 2]})
            //            self.assertEquals(o, 6)
            //            self.assertAllEqual(grad, [1] * 3)
        }
    }
}

tensorflow框架的.NET版本,提供了丰富的特性和API,可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。