You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

CondTestCases.cs 2.8 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  using Microsoft.VisualStudio.TestTools.UnitTesting;
  using Newtonsoft.Json;
  using System;
  using Tensorflow;
  using static Tensorflow.Python;

  namespace TensorFlowNET.UnitTest.control_flow_ops_test
  {
      /// <summary>
      /// Excerpt of tensorflow/python/ops/control_flow_ops_test.py: checks that
      /// <c>control_flow_ops.cond</c> returns the value produced by the branch
      /// selected by its predicate.
      /// </summary>
      /// <remarks>
      /// Every test here is currently ignored because it needs TensorFlow to expose
      /// the AddControlInput API. The remaining Python test cases of this class are
      /// not ported: they are either unnecessary thanks to C#'s strong typing or
      /// exercise a deprecated API.
      /// </remarks>
      [TestClass]
      public class CondTestCases : PythonTest
      {
          /// <summary>Shared reason string for the ignored tests in this class.</summary>
          private const string IgnoreReason = "need tensorflow expose AddControlInput API";

          /// <summary>
          /// Constant-only branches with a true predicate (2 &lt; 5), run in an
          /// explicit session: expects the true branch's constant, 22.
          /// </summary>
          [Ignore(IgnoreReason)]
          [TestMethod]
          public void testCondTrue_ConstOnly()
          {
              var graph = tf.Graph().as_default();
              with(tf.Session(graph), sess =>
              {
                  var x = tf.constant(2, name: "x");
                  var y = tf.constant(5, name: "y");
                  var z = control_flow_ops.cond(tf.less(x, y),
                      () => tf.constant(22, name: "t22"),
                      () => tf.constant(55, name: "f55"));
                  int result = z.eval(sess);
                  assertEquals(result, 22);
              });
          }

          /// <summary>
          /// Constant-only branches with a false predicate (2 &lt; 1 is false), run in
          /// an explicit session: expects the false branch's constant, 11.
          /// </summary>
          [Ignore(IgnoreReason)]
          [TestMethod]
          public void testCondFalse_ConstOnly()
          {
              var graph = tf.Graph().as_default();
              with(tf.Session(graph), sess =>
              {
                  var x = tf.constant(2, name: "x");
                  var y = tf.constant(1, name: "y");
                  var z = control_flow_ops.cond(tf.less(x, y),
                      () => tf.constant(22, name: "t22"),
                      () => tf.constant(11, name: "f11"));
                  int result = z.eval(sess);
                  assertEquals(result, 11);
              });
          }

          /// <summary>
          /// Branches that compute from the captured inputs, with a true predicate
          /// (2 &lt; 5): expects the true branch, x * 17 = 34.
          /// </summary>
          [Ignore(IgnoreReason)]
          [TestMethod]
          public void testCondTrue()
          {
              tf.Graph().as_default();
              var x = tf.constant(2, name: "x");
              var y = tf.constant(5, name: "y");
              var z = control_flow_ops.cond(tf.less(x, y),
                  () => tf.multiply(x, 17),
                  () => tf.add(y, 23));
              var result = evaluate<int>(z);
              assertEquals(result, 34);
          }

          /// <summary>
          /// Branches that compute from the captured inputs, with a false predicate
          /// (2 &lt; 1 is false): expects the false branch, y + 23 = 24.
          /// </summary>
          [Ignore(IgnoreReason)]
          [TestMethod]
          public void testCondFalse()
          {
              tf.Graph().as_default();
              var x = tf.constant(2);
              var y = tf.constant(1);
              var z = control_flow_ops.cond(tf.less(x, y),
                  () => tf.multiply(x, 17),
                  () => tf.add(y, 23));
              var result = evaluate<int>(z);
              assertEquals(result, 24);
          }
      }
  }