You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

CondTestCases.cs 2.6 kB

(line numbers 1–89)
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow;
  3. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.ControlFlowTest
  {
      /// <summary>
      /// Port of a subset of the <c>cond</c> test cases from
      /// tensorflow/python/ops/control_flow_ops_test.py.
      /// </summary>
      /// <remarks>
      /// Each test builds a small graph in which <c>control_flow_ops.cond</c> chooses
      /// between two branch lambdas based on the predicate <c>tf.less(x, y)</c>, then
      /// asserts the value produced by the branch that should have been taken.
      /// </remarks>
      [TestClass]
      public class CondTestCases : GraphModeTestBase
      {
          /// <summary>
          /// Predicate 2 &lt; 5 is true, so <c>cond</c> must yield the true-branch
          /// constant 22 (the false-branch constant is 55).
          /// </summary>
          [Ignore("Dependent on UpdateEdge")]
          [TestMethod]
          public void testCondTrue_ConstOnly()
          {
              var graph = tf.Graph().as_default();
              using (var sess = tf.Session(graph))
              {
                  var x = tf.constant(2, name: "x");
                  var y = tf.constant(5, name: "y");
                  var z = control_flow_ops.cond(tf.less(x, y),
                      () => tf.constant(22, name: "t22"),
                      () => tf.constant(55, name: "f55"));
                  int result = z.eval(sess);
                  assertEquals(result, 22);
              }
          }

          /// <summary>
          /// Predicate 2 &lt; 1 is false, so <c>cond</c> must yield the false-branch
          /// constant 11 (the true-branch constant is 22).
          /// </summary>
          [TestMethod]
          public void testCondFalse_ConstOnly()
          {
              var graph = tf.Graph().as_default();
              using (var sess = tf.Session(graph))
              {
                  var x = tf.constant(2, name: "x");
                  var y = tf.constant(1, name: "y");
                  var z = control_flow_ops.cond(tf.less(x, y),
                      () => tf.constant(22, name: "t22"),
                      () => tf.constant(11, name: "f11"));
                  int result = z.eval(sess);
                  assertEquals(result, 11);
              }
          }

          /// <summary>
          /// Predicate 2 &lt; 5 is true, so the true branch <c>x * 17</c> runs:
          /// expected result 2 * 17 = 34.
          /// </summary>
          [Ignore("Dependent on UpdateEdge")]
          [TestMethod]
          public void testCondTrue()
          {
              // Installs a fresh default graph; the returned handle is not needed here
              // because evaluation goes through the base-class evaluate helper.
              tf.Graph().as_default();
              var x = tf.constant(2, name: "x");
              var y = tf.constant(5, name: "y");
              var z = control_flow_ops.cond(tf.less(x, y),
                  () => tf.multiply(x, 17),
                  () => tf.add(y, 23));
              // NOTE(review): evaluate<T> is not defined in this file; presumably it
              // comes from GraphModeTestBase or its ancestors — confirm.
              var result = evaluate<int>(z);
              assertEquals(result, 34);
          }

          /// <summary>
          /// Predicate 2 &lt; 1 is false, so the false branch <c>y + 23</c> runs:
          /// expected result 1 + 23 = 24.
          /// </summary>
          [Ignore("Dependent on UpdateEdge")]
          [TestMethod]
          public void testCondFalse()
          {
              // Installs a fresh default graph; the returned handle is not needed here.
              tf.Graph().as_default();
              var x = tf.constant(2);
              var y = tf.constant(1);
              var z = control_flow_ops.cond(tf.less(x, y),
                  () => tf.multiply(x, 17),
                  () => tf.add(y, 23));
              var result = evaluate<int>(z);
              assertEquals(result, 24);
          }

          // NOTE: all other python test cases of this class are either not needed due to strong typing or test a deprecated api
      }
  }