You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

NeuralNetXor.cs 3.7 kB

  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using NumSharp;
  5. using Tensorflow;
  6. namespace TensorFlowNET.Examples
  7. {
  8. /// <summary>
  9. /// Simple vanilla neural net solving the famous XOR problem
  10. /// https://github.com/amygdala/tensorflow-workshop/blob/master/workshop_sections/getting_started/xor/README.md
  11. /// </summary>
  12. public class NeuralNetXor : Python, IExample
  13. {
  14. public int Priority => 2;
  15. public bool Enabled { get; set; } = true;
  16. public string Name => "NN XOR";
  17. public int num_steps = 5000;
  18. private (Operation, Tensor, RefVariable) make_graph(Tensor features,Tensor labels, int num_hidden = 8)
  19. {
  20. var stddev = 1 / Math.Sqrt(2);
  21. var hidden_weights = tf.Variable(tf.truncated_normal(new int []{2, num_hidden}, stddev: (float) stddev ));
  22. // Shape [4, num_hidden]
  23. var hidden_activations = tf.nn.relu(tf.matmul(features, hidden_weights));
  24. var output_weights = tf.Variable(tf.truncated_normal(
  25. new[] {num_hidden, 1},
  26. stddev: (float) (1 / Math.Sqrt(num_hidden))
  27. ));
  28. // Shape [4, 1]
  29. var logits = tf.matmul(hidden_activations, output_weights);
  30. // Shape [4]
  31. var predictions = tf.sigmoid(tf.squeeze(logits));
  32. var loss = tf.reduce_mean(tf.square(predictions - tf.cast(labels, tf.float32)));
  33. var gs = tf.Variable(0, trainable: false);
  34. var train_op = tf.train.GradientDescentOptimizer(0.2f).minimize(loss, global_step: gs);
  35. return (train_op, loss, gs);
  36. }
  37. public bool Run()
  38. {
  39. var graph = tf.Graph().as_default();
  40. var features = tf.placeholder(tf.float32, new TensorShape(4, 2));
  41. var labels = tf.placeholder(tf.int32, new TensorShape(4));
  42. var (train_op, loss, gs) = make_graph(features, labels);
  43. var init = tf.global_variables_initializer();
  44. // Start tf session
  45. with(tf.Session(graph), sess =>
  46. {
  47. init.run();
  48. var step = 0;
  49. //TODO: make the type conversion and jagged array initializer work with numpy
  50. //var xy = np.array(new bool[,]
  51. //{
  52. // {true, false},
  53. // {true, true },
  54. // {false, false },
  55. // {false, true},
  56. //}, dtype: np.float32);
  57. var xy = np.array(new float[]
  58. {
  59. 1, 0,
  60. 1, 1,
  61. 0, 0,
  62. 0, 1
  63. }, np.float32).reshape(4,2);
  64. //var y_ = np.array(new[] {true, false, false, true}, dtype: np.int32);
  65. var y_ = np.array(new int[] { 1, 0, 0, 1 }, dtype: np.int32);
  66. NDArray loss_value=null;
  67. while (step < num_steps)
  68. {
  69. // original python:
  70. //_, step, loss_value = sess.run(
  71. // [train_op, gs, loss],
  72. // feed_dict={features: xy, labels: y_}
  73. // )
  74. loss_value = sess.run(loss, new FeedItem(features, xy), new FeedItem(labels, y_));
  75. step++;
  76. if (step%1000==0)
  77. Console.WriteLine($"Step {0} loss: {loss_value[0]}");
  78. }
  79. Console.WriteLine($"Final loss: {loss_value[0]}");
  80. });
  81. return true;
  82. }
  83. public void PrepareData()
  84. {
  85. }
  86. }
  87. }

tensorflow框架的.NET版本,提供了丰富的特性和API,可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。