You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CondTestCases.cs 2.9 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Newtonsoft.Json;
  3. using System;
  4. using Tensorflow;
  namespace TensorFlowNET.UnitTest.control_flow_ops_test
  {
      /// <summary>
      /// Graph-mode tests for <c>control_flow_ops.cond</c>; an excerpt of
      /// tensorflow/python/framework/ops/control_flow_ops_test.py.
      /// </summary>
      /// <remarks>
      /// The other Python test cases of that class are not ported: they are either
      /// unnecessary because of C#'s strong typing or they test a deprecated API.
      /// </remarks>
      [TestClass]
      public class CondTestCases : PythonTest
      {
          /// <summary>
          /// Both branches are constants that ignore the operands. Because 2 is less
          /// than 5, the true branch (22) must be selected.
          /// </summary>
          [TestMethod]
          public void testCondTrue_ConstOnly()
          {
              AssertCondResult(
                  () => tf.constant(2, name: "x"),
                  () => tf.constant(5, name: "y"),
                  (x, y) => tf.constant(22, name: "t2"),
                  (x, y) => tf.constant(55, name: "f5"),
                  expected: 22);
          }

          /// <summary>
          /// Both branches are constants that ignore the operands. Because 2 is not
          /// less than 1, the false branch (11) must be selected.
          /// </summary>
          [TestMethod]
          public void testCondFalse_ConstOnly()
          {
              AssertCondResult(
                  () => tf.constant(2, name: "x"),
                  () => tf.constant(1, name: "y"),
                  (x, y) => tf.constant(22, name: "t2"),
                  (x, y) => tf.constant(11, name: "f1"),
                  expected: 11);
          }

          /// <summary>
          /// Both branches compute from the operands. Because 2 is less than 5, the
          /// true branch x * 17 = 34 must be selected.
          /// </summary>
          [TestMethod]
          public void testCondTrue()
          {
              AssertCondResult(
                  () => tf.constant(2),
                  () => tf.constant(5),
                  (x, y) => tf.multiply(x, tf.constant(17)),
                  (x, y) => tf.add(y, tf.constant(23)),
                  expected: 34);
          }

          /// <summary>
          /// Both branches compute from the operands. Because 2 is not less than 1,
          /// the false branch y + 23 = 24 must be selected.
          /// </summary>
          [TestMethod]
          public void testCondFalse()
          {
              AssertCondResult(
                  () => tf.constant(2),
                  () => tf.constant(1),
                  (x, y) => tf.multiply(x, tf.constant(17)),
                  (x, y) => tf.add(y, tf.constant(23)),
                  expected: 24);
          }

          /// <summary>
          /// Builds <c>cond(less(x, y), trueFn(x, y), falseFn(x, y))</c> in a new
          /// graph made default via <c>as_default()</c>. It then evaluates the result
          /// in a session on that graph and asserts it equals <paramref name="expected"/>.
          /// </summary>
          /// <param name="makeX">Creates the left operand of the predicate; invoked first, inside the session scope.</param>
          /// <param name="makeY">Creates the right operand of the predicate; invoked after <paramref name="makeX"/>.</param>
          /// <param name="trueFn">Builds the branch taken when x is less than y; receives both operands.</param>
          /// <param name="falseFn">Builds the branch taken otherwise; receives both operands.</param>
          /// <param name="expected">The int value the evaluated cond output must equal.</param>
          private void AssertCondResult(Func<Tensor> makeX, Func<Tensor> makeY,
              Func<Tensor, Tensor, Tensor> trueFn, Func<Tensor, Tensor, Tensor> falseFn,
              int expected)
          {
              var graph = tf.Graph().as_default();
              with(tf.Session(graph), sess =>
              {
                  // Operands must be created inside the session scope so they land in `graph`.
                  var x = makeX();
                  var y = makeY();
                  // Branch builders are deferred; cond invokes them while building its branch subgraphs.
                  var z = control_flow_ops.cond(tf.less(x, y),
                      () => trueFn(x, y),
                      () => falseFn(x, y));
                  int result = z.eval(sess);
                  assertEquals(result, expected);
              });
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。