You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

LogisticRegression.cs 6.7 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Diagnostics;
  5. using System.IO;
  6. using System.Linq;
  7. using System.Text;
  8. using Tensorflow;
  9. using TensorFlowNET.Examples.Utility;
  10. using static Tensorflow.Python;
  11. namespace TensorFlowNET.Examples
  12. {
  13. /// <summary>
  14. /// A logistic regression learning algorithm example using TensorFlow library.
  15. /// This example is using the MNIST database of handwritten digits
  16. /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/logistic_regression.py
  17. /// </summary>
  18. public class LogisticRegression : IExample
  19. {
  20. public bool Enabled { get; set; } = true;
  21. public string Name => "Logistic Regression";
  22. public bool IsImportingGraph { get; set; } = false;
  23. public int training_epochs = 10;
  24. public int? train_size = null;
  25. public int validation_size = 5000;
  26. public int? test_size = null;
  27. public int batch_size = 100;
  28. private float learning_rate = 0.01f;
  29. private int display_step = 1;
  30. Datasets<DataSetMnist> mnist;
  31. public bool Run()
  32. {
  33. PrepareData();
  34. // tf Graph Input
  35. var x = tf.placeholder(tf.float32, new TensorShape(-1, 784)); // mnist data image of shape 28*28=784
  36. var y = tf.placeholder(tf.float32, new TensorShape(-1, 10)); // 0-9 digits recognition => 10 classes
  37. // Set model weights
  38. var W = tf.Variable(tf.zeros(new Shape(784, 10)));
  39. var b = tf.Variable(tf.zeros(new Shape(10)));
  40. // Construct model
  41. var pred = tf.nn.softmax(tf.matmul(x, W) + b); // Softmax
  42. // Minimize error using cross entropy
  43. var cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices: 1));
  44. // Gradient Descent
  45. var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);
  46. // Initialize the variables (i.e. assign their default value)
  47. var init = tf.global_variables_initializer();
  48. var sw = new Stopwatch();
  49. return with(tf.Session(), sess =>
  50. {
  51. // Run the initializer
  52. sess.run(init);
  53. // Training cycle
  54. foreach (var epoch in range(training_epochs))
  55. {
  56. sw.Start();
  57. var avg_cost = 0.0f;
  58. var total_batch = mnist.train.num_examples / batch_size;
  59. // Loop over all batches
  60. foreach (var i in range(total_batch))
  61. {
  62. var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size);
  63. // Run optimization op (backprop) and cost op (to get loss value)
  64. var result = sess.run(new object[] { optimizer, cost },
  65. new FeedItem(x, batch_xs),
  66. new FeedItem(y, batch_ys));
  67. float c = result[1];
  68. // Compute average loss
  69. avg_cost += c / total_batch;
  70. }
  71. sw.Stop();
  72. // Display logs per epoch step
  73. if ((epoch + 1) % display_step == 0)
  74. print($"Epoch: {(epoch + 1).ToString("D4")} Cost: {avg_cost.ToString("G9")} Elapse: {sw.ElapsedMilliseconds}ms");
  75. sw.Reset();
  76. }
  77. print("Optimization Finished!");
  78. // SaveModel(sess);
  79. // Test model
  80. var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1));
  81. // Calculate accuracy
  82. var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32));
  83. float acc = accuracy.eval(new FeedItem(x, mnist.test.data), new FeedItem(y, mnist.test.labels));
  84. print($"Accuracy: {acc.ToString("F4")}");
  85. return acc > 0.9;
  86. });
  87. }
  88. public void PrepareData()
  89. {
  90. mnist = MNIST.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size: validation_size, test_size: test_size);
  91. }
  92. public void SaveModel(Session sess)
  93. {
  94. var saver = tf.train.Saver();
  95. var save_path = saver.save(sess, "logistic_regression/model.ckpt");
  96. tf.train.write_graph(sess.graph, "logistic_regression", "model.pbtxt", as_text: true);
  97. FreezeGraph.freeze_graph(input_graph: "logistic_regression/model.pbtxt",
  98. input_saver: "",
  99. input_binary: false,
  100. input_checkpoint: "logistic_regression/model.ckpt",
  101. output_node_names: "Softmax",
  102. restore_op_name: "save/restore_all",
  103. filename_tensor_name: "save/Const:0",
  104. output_graph: "logistic_regression/model.pb",
  105. clear_devices: true,
  106. initializer_nodes: "");
  107. }
  108. public void Predict(Session sess)
  109. {
  110. var graph = new Graph().as_default();
  111. graph.Import(Path.Join("logistic_regression", "model.pb"));
  112. // restoring the model
  113. // var saver = tf.train.import_meta_graph("logistic_regression/tensorflowModel.ckpt.meta");
  114. // saver.restore(sess, tf.train.latest_checkpoint('logistic_regression'));
  115. var pred = graph.OperationByName("Softmax");
  116. var output = pred.outputs[0];
  117. var x = graph.OperationByName("Placeholder");
  118. var input = x.outputs[0];
  119. // predict
  120. var (batch_xs, batch_ys) = mnist.train.next_batch(10);
  121. var results = sess.run(output, new FeedItem(input, batch_xs[np.arange(1)]));
  122. if (results.argmax() == (batch_ys[0] as NDArray).argmax())
  123. print("predicted OK!");
  124. else
  125. throw new ValueError("predict error, should be 90% accuracy");
  126. }
  127. public Graph ImportGraph()
  128. {
  129. throw new NotImplementedException();
  130. }
  131. public Graph BuildGraph()
  132. {
  133. throw new NotImplementedException();
  134. }
  135. public void Train(Session sess)
  136. {
  137. throw new NotImplementedException();
  138. }
  139. public void Test(Session sess)
  140. {
  141. throw new NotImplementedException();
  142. }
  143. }
  144. }