You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

NeuralNetXor.cs 5.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using NumSharp;
  5. using Tensorflow;
  6. using TensorFlowNET.Examples.Utility;
  7. namespace TensorFlowNET.Examples
  8. {
  9. /// <summary>
  10. /// Simple vanilla neural net solving the famous XOR problem
  11. /// https://github.com/amygdala/tensorflow-workshop/blob/master/workshop_sections/getting_started/xor/README.md
  12. /// </summary>
  13. public class NeuralNetXor : Python, IExample
  14. {
  15. public int Priority => 10;
  16. public bool Enabled { get; set; } = true;
  17. public string Name => "NN XOR";
  18. public bool ImportGraph { get; set; } = true;
  19. public int num_steps = 5000;
  20. private NDArray data;
  21. private (Operation, Tensor, Tensor) make_graph(Tensor features,Tensor labels, int num_hidden = 8)
  22. {
  23. var stddev = 1 / Math.Sqrt(2);
  24. var hidden_weights = tf.Variable(tf.truncated_normal(new int []{2, num_hidden}, stddev: (float) stddev ));
  25. // Shape [4, num_hidden]
  26. var hidden_activations = tf.nn.relu(tf.matmul(features, hidden_weights));
  27. var output_weights = tf.Variable(tf.truncated_normal(
  28. new[] {num_hidden, 1},
  29. stddev: (float) (1 / Math.Sqrt(num_hidden))
  30. ));
  31. // Shape [4, 1]
  32. var logits = tf.matmul(hidden_activations, output_weights);
  33. // Shape [4]
  34. var predictions = tf.sigmoid(tf.squeeze(logits));
  35. var loss = tf.reduce_mean(tf.square(predictions - tf.cast(labels, tf.float32)), name:"loss");
  36. var gs = tf.Variable(0, trainable: false);
  37. var train_op = tf.train.GradientDescentOptimizer(0.2f).minimize(loss, global_step: gs);
  38. return (train_op, loss, gs);
  39. }
  40. public bool Run()
  41. {
  42. PrepareData();
  43. float loss_value = 0;
  44. if (ImportGraph)
  45. loss_value = RunWithImportedGraph();
  46. else
  47. loss_value=RunWithBuiltGraph();
  48. return loss_value < 0.0628;
  49. }
  50. private float RunWithImportedGraph()
  51. {
  52. var graph = tf.Graph().as_default();
  53. tf.train.import_meta_graph("graph/xor.meta");
  54. Tensor features = graph.get_operation_by_name("Placeholder");
  55. Tensor labels = graph.get_operation_by_name("Placeholder_1");
  56. Tensor loss = graph.get_operation_by_name("loss");
  57. Tensor train_op = graph.get_operation_by_name("train_op");
  58. Tensor global_step = graph.get_operation_by_name("global_step");
  59. var init = tf.global_variables_initializer();
  60. float loss_value = 0;
  61. // Start tf session
  62. with<Session>(tf.Session(graph), sess =>
  63. {
  64. sess.run(init);
  65. var step = 0;
  66. var y_ = np.array(new int[] { 1, 0, 0, 1 }, dtype: np.int32);
  67. while (step < num_steps)
  68. {
  69. // original python:
  70. //_, step, loss_value = sess.run(
  71. // [train_op, gs, loss],
  72. // feed_dict={features: xy, labels: y_}
  73. // )
  74. var result = sess.run(new ITensorOrOperation[] { train_op, global_step, loss }, new FeedItem(features, data), new FeedItem(labels, y_));
  75. loss_value = result[2];
  76. step++;
  77. if (step % 1000 == 0)
  78. Console.WriteLine($"Step {step} loss: {loss_value}");
  79. }
  80. Console.WriteLine($"Final loss: {loss_value}");
  81. });
  82. return loss_value;
  83. }
  84. private float RunWithBuiltGraph()
  85. {
  86. var graph = tf.Graph().as_default();
  87. var features = tf.placeholder(tf.float32, new TensorShape(4, 2));
  88. var labels = tf.placeholder(tf.int32, new TensorShape(4));
  89. var (train_op, loss, gs) = make_graph(features, labels);
  90. var init = tf.global_variables_initializer();
  91. float loss_value = 0;
  92. // Start tf session
  93. with(tf.Session(graph), sess =>
  94. {
  95. sess.run(init);
  96. var step = 0;
  97. var y_ = np.array(new int[] { 1, 0, 0, 1 }, dtype: np.int32);
  98. while (step < num_steps)
  99. {
  100. // original python:
  101. //_, step, loss_value = sess.run(
  102. // [train_op, gs, loss],
  103. // feed_dict={features: xy, labels: y_}
  104. // )
  105. var result = sess.run(new ITensorOrOperation[] { train_op, gs, loss }, new FeedItem(features, data), new FeedItem(labels, y_));
  106. loss_value = result[2];
  107. //step = result[1];
  108. step++;
  109. if (step % 1000 == 0)
  110. Console.WriteLine($"Step {step} loss: {loss_value}");
  111. }
  112. Console.WriteLine($"Final loss: {loss_value}");
  113. });
  114. return loss_value;
  115. }
  116. public void PrepareData()
  117. {
  118. data = new float[,]
  119. {
  120. {1, 0 },
  121. {1, 1 },
  122. {0, 0 },
  123. {0, 1 }
  124. };
  125. // download graph meta data
  126. string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/xor.meta";
  127. Web.Download(url, "graph", "xor.meta");
  128. }
  129. }
  130. }

tensorflow框架的.NET版本,提供了丰富的特性和API,可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。