You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

CondTestCases.cs 5.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Newtonsoft.Json;
  3. using System;
  4. using Tensorflow;
  namespace TensorFlowNET.UnitTest.control_flow_ops_test
  {
      /// <summary>
      /// Port of the <c>CondTest</c> cases from TensorFlow's Python <c>control_flow_ops_test.py</c>.
      /// </summary>
      /// <remarks>
      /// Cases that have not been ported yet are kept as ignored stubs whose bodies carry the
      /// original Python test as comments.
      /// </remarks>
      [TestClass]
      public class CondTestCases : PythonTest
      {
          /// <summary>
          /// <c>cond(2 &lt; 5, ...)</c> must take the true branch, so the result is the true
          /// branch's constant, 2.
          /// </summary>
          /// <remarks>
          /// This port uses constant branches rather than the arithmetic branches
          /// (<c>multiply(x, 17)</c> / <c>add(y, 23)</c>) used by the Python tests.
          /// </remarks>
          [TestMethod]
          public void testCondTrue()
          {
              var graph = tf.Graph().as_default();

              // Bind the session to the graph built above so the test does not depend on
              // which graph happens to be the process-wide default.
              with(tf.Session(graph), sess =>
              {
                  var x = tf.constant(2, name: "x");
                  var y = tf.constant(5, name: "y");
                  var pred = tf.less(x, y);

                  Func<ITensorOrOperation> if_true = () => tf.constant(2, name: "t2");
                  Func<ITensorOrOperation> if_false = () => tf.constant(5, name: "f5");

                  var z = control_flow_ops.cond(pred, if_true, if_false);
                  int result = z.eval(sess);
                  assertEquals(result, 2);
              });
          }

          /// <summary>
          /// <c>cond(2 &lt; 1, ...)</c> must take the false branch, so the result is the false
          /// branch's constant, 1.
          /// </summary>
          /// <remarks>
          /// The Python reference uses arithmetic branches and expects 24:
          /// <code>
          /// x = tf.constant(2)
          /// y = tf.constant(1)
          /// z = tf.cond(tf.math.less(x, y),
          ///             lambda: tf.math.multiply(x, 17),
          ///             lambda: tf.math.add(y, 23))
          /// # z.eval() == 24
          /// </code>
          /// This port uses constant branches instead, so the expected value here is 1.
          /// </remarks>
          [TestMethod]
          public void testCondFalse()
          {
              var graph = tf.Graph().as_default();

              // Same session/graph binding as testCondTrue (previously tf.Session() was used
              // and the graph built above was never referenced).
              with(tf.Session(graph), sess =>
              {
                  var x = tf.constant(2, name: "x");
                  var y = tf.constant(1, name: "y");
                  var pred = tf.less(x, y);

                  Func<ITensorOrOperation> if_true = () => tf.constant(2, name: "t2");
                  Func<ITensorOrOperation> if_false = () => tf.constant(1, name: "f1");

                  var z = control_flow_ops.cond(pred, if_true, if_false);
                  int result = z.eval(sess);
                  assertEquals(result, 1);
              });
          }

          [Ignore("Todo")]
          [TestMethod]
          public void testCondTrueLegacy()
          {
              // def testCondTrueLegacy(self):
              //   x = constant_op.constant(2)
              //   y = constant_op.constant(5)
              //   z = control_flow_ops.cond(
              //       math_ops.less(x, y),
              //       fn1=lambda: math_ops.multiply(x, 17),
              //       fn2=lambda: math_ops.add(y, 23))
              //   self.assertEquals(self.evaluate(z), 34)
          }

          [Ignore("Todo")]
          [TestMethod]
          public void testCondFalseLegacy()
          {
              // def testCondFalseLegacy(self):
              //   x = constant_op.constant(2)
              //   y = constant_op.constant(1)
              //   z = control_flow_ops.cond(
              //       math_ops.less(x, y),
              //       fn1=lambda: math_ops.multiply(x, 17),
              //       fn2=lambda: math_ops.add(y, 23))
              //   self.assertEquals(self.evaluate(z), 24)
          }

          [Ignore("Todo")]
          [TestMethod]
          public void testCondMissingArg1()
          {
              // def testCondMissingArg1(self):
              //   x = constant_op.constant(1)
              //   with self.assertRaises(TypeError):
              //     control_flow_ops.cond(True, false_fn=lambda: x)
          }

          [Ignore("Todo")]
          [TestMethod]
          public void testCondMissingArg2()
          {
              // def testCondMissingArg2(self):
              //   x = constant_op.constant(1)
              //   with self.assertRaises(TypeError):
              //     control_flow_ops.cond(True, lambda: x)
          }

          [Ignore("Todo")]
          [TestMethod]
          public void testCondDuplicateArg1()
          {
              // def testCondDuplicateArg1(self):
              //   x = constant_op.constant(1)
              //   with self.assertRaises(TypeError):
              //     control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)
          }

          [Ignore("Todo")]
          [TestMethod]
          public void testCondDuplicateArg2()
          {
              // def testCondDuplicateArg2(self):
              //   x = constant_op.constant(1)
              //   with self.assertRaises(TypeError):
              //     control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)
          }
      }
  }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。