You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, may include dashes ('-'), and can be up to 35 characters long.

CondTestCases.cs 3.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Newtonsoft.Json;
  3. using System;
  4. using Tensorflow;
  namespace TensorFlowNET.UnitTest.control_flow_ops_test
  {
      /// <summary>
      /// Excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py: checks that
      /// <c>control_flow_ops.cond</c> evaluates to the value of the branch selected by its
      /// boolean predicate.
      /// </summary>
      /// <remarks>
      /// The remaining Python test cases of this class are not ported: they are either
      /// unnecessary because of C#'s strong typing or they test a deprecated API.
      /// </remarks>
      [TestClass]
      public class CondTestCases : PythonTest
      {
          /// <summary>
          /// <c>2 &lt; 5</c> is true, so <c>cond</c> must yield the true branch's value (2).
          /// </summary>
          [TestMethod]
          public void testCondTrue()
          {
              var graph = tf.Graph().as_default();
              with(tf.Session(graph), sess =>
              {
                  var x = tf.constant(2, name: "x");
                  var y = tf.constant(5, name: "y");
                  var pred = tf.less(x, y);

                  Func<ITensorOrOperation> if_true = () => tf.constant(2, name: "t2");
                  Func<ITensorOrOperation> if_false = () => tf.constant(5, name: "f5");
                  var z = control_flow_ops.cond(pred, if_true, if_false);

                  int result = z.eval(sess);
                  assertEquals(result, 2);
              });
          }

          /// <summary>
          /// <c>2 &lt; 1</c> is false, so <c>cond</c> must yield the false branch's value (1).
          /// </summary>
          /// <remarks>
          /// Python equivalent of what this test builds and checks:
          /// <code>
          /// import tensorflow as tf
          /// with tf.Session() as sess:
          ///     x = tf.constant(2)
          ///     y = tf.constant(1)
          ///     pred = tf.math.less(x, y)
          ///     z = tf.cond(pred, lambda: tf.constant(2), lambda: tf.constant(1))
          ///     print(z.eval() == 1)
          /// </code>
          /// </remarks>
          [TestMethod]
          public void testCondFalse()
          {
              var graph = tf.Graph().as_default();
              // Bind the session to the freshly created graph explicitly, as testCondTrue does,
              // instead of relying on whatever the implicit default is at session creation.
              with(tf.Session(graph), sess =>
              {
                  var x = tf.constant(2, name: "x");
                  var y = tf.constant(1, name: "y");
                  var pred = tf.less(x, y);

                  Func<ITensorOrOperation> if_true = () => tf.constant(2, name: "t2");
                  Func<ITensorOrOperation> if_false = () => tf.constant(1, name: "f1");
                  var z = control_flow_ops.cond(pred, if_true, if_false);

                  int result = z.eval(sess);
                  assertEquals(result, 1);
              });
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。