You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CondTestCases.cs 2.6 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Newtonsoft.Json;
  3. using System;
  4. using Tensorflow;
  5. using static Tensorflow.Python;
  6. namespace TensorFlowNET.UnitTest.control_flow_ops_test
  7. {
  8. /// <summary>
  9. /// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py
  10. /// </summary>
  11. [TestClass]
  12. public class CondTestCases : PythonTest
  13. {
  14. [TestMethod]
  15. public void testCondTrue_ConstOnly()
  16. {
  17. var graph = tf.Graph().as_default();
  18. with(tf.Session(graph), sess =>
  19. {
  20. var x = tf.constant(2, name: "x");
  21. var y = tf.constant(5, name: "y");
  22. var z = control_flow_ops.cond(tf.less(x, y),
  23. () => tf.constant(22, name: "t22"),
  24. () => tf.constant(55, name: "f55"));
  25. int result = z.eval(sess);
  26. assertEquals(result, 22);
  27. });
  28. }
  29. [TestMethod]
  30. public void testCondFalse_ConstOnly()
  31. {
  32. var graph = tf.Graph().as_default();
  33. with(tf.Session(graph), sess =>
  34. {
  35. var x = tf.constant(2, name: "x");
  36. var y = tf.constant(1, name: "y");
  37. var z = control_flow_ops.cond(tf.less(x, y),
  38. () => tf.constant(22, name: "t22"),
  39. () => tf.constant(11, name: "f11"));
  40. int result = z.eval(sess);
  41. assertEquals(result, 11);
  42. });
  43. }
  44. [TestMethod]
  45. public void testCondTrue()
  46. {
  47. tf.Graph().as_default();
  48. var x = tf.constant(2, name: "x");
  49. var y = tf.constant(5, name: "y");
  50. var z = control_flow_ops.cond(tf.less(x, y),
  51. () => tf.multiply(x, 17),
  52. () => tf.add(y, 23));
  53. var result = evaluate<int>(z);
  54. assertEquals(result, 34);
  55. }
  56. [TestMethod]
  57. public void testCondFalse()
  58. {
  59. tf.Graph().as_default();
  60. var x = tf.constant(2);
  61. var y = tf.constant(1);
  62. var z = control_flow_ops.cond(tf.less(x, y),
  63. () => tf.multiply(x, 17),
  64. () => tf.add(y, 23));
  65. var result = evaluate<int>(z);
  66. assertEquals(result, 24);
  67. }
  68. // NOTE: all other python test cases of this class are either not needed due to strong typing or test a deprecated api
  69. }
  70. }