You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CondTestCases.cs 2.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow;
  3. using static Tensorflow.Binding;
  4. namespace TensorFlowNET.UnitTest.control_flow_ops_test
  5. {
  6. /// <summary>
  7. /// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py
  8. /// </summary>
  9. [TestClass]
  10. public class CondTestCases : PythonTest
  11. {
  12. [TestMethod]
  13. public void testCondTrue_ConstOnly()
  14. {
  15. var graph = tf.Graph().as_default();
  16. using (var sess = tf.Session(graph))
  17. {
  18. var x = tf.constant(2, name: "x");
  19. var y = tf.constant(5, name: "y");
  20. var z = control_flow_ops.cond(tf.less(x, y),
  21. () => tf.constant(22, name: "t22"),
  22. () => tf.constant(55, name: "f55"));
  23. int result = z.eval(sess);
  24. assertEquals(result, 22);
  25. }
  26. }
  27. [TestMethod]
  28. public void testCondFalse_ConstOnly()
  29. {
  30. var graph = tf.Graph().as_default();
  31. using (var sess = tf.Session(graph))
  32. {
  33. var x = tf.constant(2, name: "x");
  34. var y = tf.constant(1, name: "y");
  35. var z = control_flow_ops.cond(tf.less(x, y),
  36. () => tf.constant(22, name: "t22"),
  37. () => tf.constant(11, name: "f11"));
  38. int result = z.eval(sess);
  39. assertEquals(result, 11);
  40. }
  41. }
  42. [TestMethod]
  43. public void testCondTrue()
  44. {
  45. tf.Graph().as_default();
  46. var x = tf.constant(2, name: "x");
  47. var y = tf.constant(5, name: "y");
  48. var z = control_flow_ops.cond(tf.less(x, y),
  49. () => tf.multiply(x, 17),
  50. () => tf.add(y, 23));
  51. var result = evaluate<int>(z);
  52. assertEquals(result, 34);
  53. }
  54. [TestMethod]
  55. public void testCondFalse()
  56. {
  57. tf.Graph().as_default();
  58. var x = tf.constant(2);
  59. var y = tf.constant(1);
  60. var z = control_flow_ops.cond(tf.less(x, y),
  61. () => tf.multiply(x, 17),
  62. () => tf.add(y, 23));
  63. var result = evaluate<int>(z);
  64. assertEquals(result, 24);
  65. }
  66. // NOTE: all other python test cases of this class are either not needed due to strong typing or test a deprecated api
  67. }
  68. }