You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CondTestCases.cs 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using Tensorflow;
  4. namespace TensorFlowNET.UnitTest.control_flow_ops_test
  5. {
  6. /// <summary>
  7. /// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py
  8. /// </summary>
  9. [TestClass]
  10. public class CondTestCases : PythonTest
  11. {
  12. [TestMethod]
  13. public void testCondTrue()
  14. {
  15. with(tf.Session(), sess =>
  16. {
  17. var x = tf.constant(2);
  18. var y = tf.constant(5);
  19. var z = control_flow_ops.cond(tf.less(x, y),
  20. () => tf.multiply(x, 17),
  21. () => tf.add(y, 23));
  22. int result = z.eval(sess);
  23. assertEquals(result, 34);
  24. });
  25. }
  26. [TestMethod]
  27. public void testCondFalse()
  28. {
  29. /* python
  30. * import tensorflow as tf
  31. from tensorflow.python.framework import ops
  32. def if_true():
  33. return tf.math.multiply(x, 17)
  34. def if_false():
  35. return tf.math.add(y, 23)
  36. with tf.Session() as sess:
  37. x = tf.constant(2)
  38. y = tf.constant(1)
  39. pred = tf.math.less(x,y)
  40. z = tf.cond(pred, if_true, if_false)
  41. result = z.eval()
  42. print(result == 24) */
  43. with(tf.Session(), sess =>
  44. {
  45. var x = tf.constant(2);
  46. var y = tf.constant(1);
  47. var pred = tf.less(x, y);
  48. Func<ITensorOrOperation> if_true = delegate
  49. {
  50. return tf.multiply(x, 17);
  51. };
  52. Func<ITensorOrOperation> if_false = delegate
  53. {
  54. return tf.add(y, 23);
  55. };
  56. var z = control_flow_ops.cond(pred, if_true, if_false);
  57. int result = z.eval(sess);
  58. assertEquals(result, 24);
  59. });
  60. }
  61. [Ignore("Todo")]
  62. [TestMethod]
  63. public void testCondTrueLegacy()
  64. {
  65. // def testCondTrueLegacy(self):
  66. // x = constant_op.constant(2)
  67. // y = constant_op.constant(5)
  68. // z = control_flow_ops.cond(
  69. // math_ops.less(x, y),
  70. // fn1=lambda: math_ops.multiply(x, 17),
  71. // fn2=lambda: math_ops.add(y, 23))
  72. // self.assertEquals(self.evaluate(z), 34)
  73. }
  74. [Ignore("Todo")]
  75. [TestMethod]
  76. public void testCondFalseLegacy()
  77. {
  78. // def testCondFalseLegacy(self):
  79. // x = constant_op.constant(2)
  80. // y = constant_op.constant(1)
  81. // z = control_flow_ops.cond(
  82. // math_ops.less(x, y),
  83. // fn1=lambda: math_ops.multiply(x, 17),
  84. // fn2=lambda: math_ops.add(y, 23))
  85. // self.assertEquals(self.evaluate(z), 24)
  86. }
  87. [Ignore("Todo")]
  88. [TestMethod]
  89. public void testCondMissingArg1()
  90. {
  91. // def testCondMissingArg1(self):
  92. // x = constant_op.constant(1)
  93. // with self.assertRaises(TypeError):
  94. // control_flow_ops.cond(True, false_fn=lambda: x)
  95. }
  96. [Ignore("Todo")]
  97. [TestMethod]
  98. public void testCondMissingArg2()
  99. {
  100. // def testCondMissingArg2(self):
  101. // x = constant_op.constant(1)
  102. // with self.assertRaises(TypeError):
  103. // control_flow_ops.cond(True, lambda: x)
  104. }
  105. [Ignore("Todo")]
  106. [TestMethod]
  107. public void testCondDuplicateArg1()
  108. {
  109. // def testCondDuplicateArg1(self):
  110. // x = constant_op.constant(1)
  111. // with self.assertRaises(TypeError):
  112. // control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)
  113. }
  114. [Ignore("Todo")]
  115. [TestMethod]
  116. public void testCondDuplicateArg2()
  117. {
  118. // def testCondDuplicateArg2(self):
  119. // x = constant_op.constant(1)
  120. // with self.assertRaises(TypeError):
  121. // control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)
  122. }
  123. }
  124. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。