You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CondTestCases.cs 2.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow;
  3. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.ControlFlowTest
  {
      /// <summary>
      /// Excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py.
      /// NOTE(review): the upstream CondTest cases appear to live in
      /// tensorflow/python/ops/control_flow_ops_test.py; confirm and update this path.
      /// Each test builds a small graph around <c>control_flow_ops.cond</c> with a
      /// <c>tf.less(x, y)</c> predicate and checks the value produced by the selected branch.
      /// </summary>
      [TestClass]
      public class CondTestCases : GraphModeTestBase
      {
          /// <summary>
          /// Predicate is less(2, 5), which is true, so cond must return the true-branch constant 22.
          /// The test is currently ignored; the attribute cites a dependency on UpdateEdge.
          /// </summary>
          [Ignore("Dependent on UpdateEdge")]
          [TestMethod]
          public void testCondTrue_ConstOnly()
          {
              // Presumably installs a fresh graph as the default; the returned graph backs the session below.
              var graph = tf.Graph().as_default();
              // The session owns native resources, so dispose it once the result has been read.
              using (var sess = tf.Session(graph))
              {
                  var x = tf.constant(2, name: "x");
                  var y = tf.constant(5, name: "y");
                  var z = control_flow_ops.cond(tf.less(x, y),
                      () => tf.constant(22, name: "t22"),
                      () => tf.constant(55, name: "f55"));
                  int result = z.eval(sess);
                  assertEquals(result, 22);
              }
          }

          /// <summary>
          /// Predicate is less(2, 1), which is false, so cond must return the false-branch constant 11.
          /// </summary>
          [TestMethod]
          public void testCondFalse_ConstOnly()
          {
              // Presumably installs a fresh graph as the default; the returned graph backs the session below.
              var graph = tf.Graph().as_default();
              // The session owns native resources, so dispose it once the result has been read.
              using (var sess = tf.Session(graph))
              {
                  var x = tf.constant(2, name: "x");
                  var y = tf.constant(1, name: "y");
                  var z = control_flow_ops.cond(tf.less(x, y),
                      () => tf.constant(22, name: "t22"),
                      () => tf.constant(11, name: "f11"));
                  int result = z.eval(sess);
                  assertEquals(result, 11);
              }
          }

          /// <summary>
          /// Predicate is less(2, 5), which is true, so cond must evaluate x * 17 = 34.
          /// The branches reference tensors created outside the lambdas.
          /// The test is currently ignored; the attribute cites a dependency on UpdateEdge.
          /// </summary>
          [Ignore("Dependent on UpdateEdge")]
          [TestMethod]
          public void testCondTrue()
          {
              // Called for its side effect only; the returned graph is not needed here.
              tf.Graph().as_default();
              var x = tf.constant(2, name: "x");
              var y = tf.constant(5, name: "y");
              var z = control_flow_ops.cond(tf.less(x, y),
                  () => tf.multiply(x, 17),
                  () => tf.add(y, 23));
              // evaluate<T> is inherited from the test base class (assumed to run z in a managed session).
              var result = evaluate<int>(z);
              assertEquals(result, 34);
          }

          /// <summary>
          /// Predicate is less(2, 1), which is false, so cond must evaluate y + 23 = 24.
          /// The branches reference tensors created outside the lambdas.
          /// The test is currently ignored; the attribute cites a dependency on UpdateEdge.
          /// </summary>
          [Ignore("Dependent on UpdateEdge")]
          [TestMethod]
          public void testCondFalse()
          {
              // Called for its side effect only; the returned graph is not needed here.
              tf.Graph().as_default();
              var x = tf.constant(2);
              var y = tf.constant(1);
              var z = control_flow_ops.cond(tf.less(x, y),
                  () => tf.multiply(x, 17),
                  () => tf.add(y, 23));
              // evaluate<T> is inherited from the test base class (assumed to run z in a managed session).
              var result = evaluate<int>(z);
              assertEquals(result, 24);
          }

          // NOTE: all other python test cases of this class are either not needed due to strong typing or test a deprecated api
      }
  }