You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

NearestNeighbor.cs 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using Tensorflow;
  6. using TensorFlowNET.Examples.Utility;
  7. using static Tensorflow.Python;
  namespace TensorFlowNET.Examples
  {
      /// <summary>
      /// A nearest neighbor learning algorithm example.
      /// This example uses the MNIST database of handwritten digits.
      /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/nearest_neighbor.py
      /// </summary>
      public class NearestNeighbor : IExample
      {
          public bool Enabled { get; set; } = true;
          public string Name => "Nearest Neighbor";

          Datasets mnist;

          // Training images/labels (the nearest-neighbor candidates) and test images/labels,
          // populated by PrepareData().
          NDArray Xtr, Ytr, Xte, Yte;

          public int? TrainSize = null;
          public int ValidationSize = 5000;
          public int? TestSize = null;
          public bool IsImportingGraph { get; set; } = false;

          /// <summary>
          /// Builds an L1-distance nearest-neighbor graph, classifies every test sample by
          /// copying the label of its closest training sample, and prints the accuracy.
          /// </summary>
          /// <returns><c>true</c> when the measured accuracy is above 0.8.</returns>
          public bool Run()
          {
              // tf Graph Input: a batch of training images and a single test image.
              var xtr = tf.placeholder(tf.float32, new TensorShape(-1, 784));
              var xte = tf.placeholder(tf.float32, new TensorShape(784));

              // L1 distance between the test image and each training row: sum(|xtr - xte|) over axis 1.
              var distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices: 1);

              // Prediction: index of the training row with the minimum distance (the nearest neighbor).
              var pred = tf.arg_min(distance, 0);

              float accuracy = 0f;

              // The graph above declares no variables; the initializer is kept for parity with
              // the Python example it is ported from.
              var init = tf.global_variables_initializer();

              with(tf.Session(), sess =>
              {
                  // Run the initializer
                  sess.run(init);

                  PrepareData();

                  var testCount = Xte.shape[0];
                  int correct = 0;

                  foreach (int i in range(testCount))
                  {
                      // Get nearest neighbor
                      long nn_index = sess.run(pred, new FeedItem(xtr, Xtr), new FeedItem(xte, Xte[i]));
                      int index = (int)nn_index;

                      // Nearest neighbor's class label vs. the true label (labels are one-hot).
                      int predicted = (int)np.argmax(Ytr[index]);
                      int actual = (int)np.argmax(Yte[i]);

                      if (i % 10 == 0)
                      {
                          print($"Test {i} Prediction: {predicted} True Class: {actual}");
                      }

                      if (predicted == actual)
                      {
                          correct++;
                      }
                  }

                  // Count hits as an integer and divide once: summing 1f/N per hit accumulates
                  // rounding error that can flip the strict "> 0.8" check below.
                  accuracy = testCount > 0 ? (float)correct / testCount : 0f;
                  print($"Accuracy: {accuracy}");
              });

              return accuracy > 0.8;
          }

          /// <summary>
          /// Loads MNIST (one-hot labels) and takes a small slice of it into the training
          /// and test fields used by <see cref="Run"/>.
          /// </summary>
          public void PrepareData()
          {
              mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size: ValidationSize, test_size: TestSize);

              // In this example, we limit mnist data: by default 5000 training candidates and
              // 200 test samples, otherwise 1% of the configured sizes. Clamp to at least one
              // sample, because an empty training batch leaves arg_min nothing to choose from.
              int trainBatch = TrainSize is null ? 5000 : Math.Max(1, TrainSize.Value / 100);
              int testBatch = TestSize is null ? 200 : Math.Max(1, TestSize.Value / 100);

              (Xtr, Ytr) = mnist.train.next_batch(trainBatch);
              (Xte, Yte) = mnist.test.next_batch(testBatch);
          }

          /// <summary>Not supported by this example.</summary>
          /// <exception cref="NotImplementedException">Always thrown.</exception>
          public Graph ImportGraph()
          {
              throw new NotImplementedException();
          }

          /// <summary>Not supported by this example; the graph is built inline in <see cref="Run"/>.</summary>
          /// <exception cref="NotImplementedException">Always thrown.</exception>
          public Graph BuildGraph()
          {
              throw new NotImplementedException();
          }

          /// <summary>Not supported by this example.</summary>
          /// <exception cref="NotImplementedException">Always thrown.</exception>
          public bool Train()
          {
              throw new NotImplementedException();
          }

          /// <summary>Not supported by this example.</summary>
          /// <exception cref="NotImplementedException">Always thrown.</exception>
          public bool Predict()
          {
              throw new NotImplementedException();
          }
      }
  }