You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

KMeansClustering.cs 5.2 kB

(Line-number gutter: lines 1–130)
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Diagnostics;
  5. using System.Linq;
  6. using System.Text;
  7. using Tensorflow;
  8. using Tensorflow.Clustering;
  9. using TensorFlowNET.Examples.Utility;
  namespace TensorFlowNET.Examples
  {
      /// <summary>
      /// Implements the K-Means algorithm with TensorFlow.NET and applies it to classify
      /// handwritten digit images.
      /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/kmeans.py
      /// </summary>
      /// <remarks>
      /// The K-means graph is not built in C#: a pre-exported meta graph (downloaded in
      /// <see cref="PrepareData"/>) is imported and its tensors/ops are looked up by name.
      /// </remarks>
      public class KMeansClustering : Python, IExample
      {
          public int Priority => 8;
          public bool Enabled { get; set; } = true;
          public string Name => "K-means Clustering";
          public bool ImportGraph { get; set; } = true;

          /// <summary>Training-set size passed to the MNIST loader (<c>null</c> is forwarded as-is).</summary>
          public int? train_size = null;
          /// <summary>Validation-set size passed to the MNIST loader.</summary>
          public int validation_size = 5000;
          /// <summary>Test-set size passed to the MNIST loader (<c>null</c> is forwarded as-is).</summary>
          public int? test_size = null;
          /// <summary>The number of samples per batch.</summary>
          /// <remarks>NOTE(review): not read by <see cref="Run"/>, which feeds the full training set every step.</remarks>
          public int batch_size = 1024;

          // Local directory and file name of the pre-exported K-means meta graph; shared by
          // PrepareData (download target) and Run (import source) so the two cannot drift apart.
          private const string MetaGraphDir = "graph";
          private const string MetaGraphFile = "kmeans.meta";

          private Datasets mnist;
          private NDArray full_data_x;
          private readonly int num_steps = 20; // Total steps to train
          private readonly int k = 25; // The number of clusters
          private readonly int num_classes = 10; // The 10 digits
          private readonly int num_features = 784; // Each image is 28x28 pixels (kept for reference; the shape lives in the imported graph)

          /// <summary>
          /// Loads MNIST and the K-means meta graph, runs <c>num_steps</c> training steps on the
          /// full training set, maps each centroid to its most frequent digit label, and evaluates
          /// the resulting classifier on the test set.
          /// </summary>
          /// <returns><c>true</c> when the test accuracy is greater than 0.70.</returns>
          /// <exception cref="InvalidOperationException">
          /// Thrown when <c>num_steps</c> is less than 1, so no cluster assignments were produced.
          /// </exception>
          public bool Run()
          {
              PrepareData();

              var graph = tf.Graph().as_default();
              tf.train.import_meta_graph($"{MetaGraphDir}/{MetaGraphFile}");

              // Input images; corresponds to
              // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_features)).
              Tensor X = graph.get_operation_by_name("Placeholder");
              // Labels (for assigning a label to a centroid and testing); corresponds to
              // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_classes)).
              Tensor Y = graph.get_operation_by_name("Placeholder_1");

              // Instead of building KMeans(X, k, distance_metric: COSINE_DISTANCE, use_mini_batch: true)
              // and its training graph here, the corresponding ops are fetched by name from the
              // imported graph.
              var init_vars = tf.global_variables_initializer();
              // NOTE(review): the op names below are fixed by the exported graph; presumably
              // "cond/Merge" initializes the centroids and "group_deps" is one training step — confirm.
              Tensor init_op = graph.get_operation_by_name("cond/Merge");
              var train_op = graph.get_operation_by_name("group_deps");
              Tensor avg_distance = graph.get_operation_by_name("Mean");
              Tensor cluster_idx = graph.get_operation_by_name("Squeeze_1");

              NDArray result = null;
              with(tf.Session(graph), sess =>
              {
                  sess.run(init_vars, new FeedItem(X, full_data_x));
                  sess.run(init_op, new FeedItem(X, full_data_x));

                  // Training: every step feeds the full training set.
                  var sw = new Stopwatch();
                  foreach (var i in range(1, num_steps + 1))
                  {
                      sw.Restart();
                      result = sess.run(new ITensorOrOperation[] { train_op, avg_distance, cluster_idx }, new FeedItem(X, full_data_x));
                      sw.Stop();
                      if (i % 4 == 0 || i == 1)
                          print($"Step {i}, Avg Distance: {result[1]} Elapse: {sw.ElapsedMilliseconds}ms");
                  }

                  // `is null` rather than `== null`: NDArray overloads the equality operators.
                  if (result is null)
                      throw new InvalidOperationException($"{nameof(num_steps)} must be at least 1 to produce cluster assignments.");

                  // Closest-centroid index for each training sample, from the last training step.
                  var idx = result[2].Data<int>();

                  // Assign a label to each centroid:
                  // count the total number of labels per centroid by adding each training sample's
                  // (one-hot) label to the row of its closest centroid (given by 'idx').
                  var counts = np.zeros((k, num_classes), np.float32);
                  // Restart, not Start: the stopwatch still holds the last training step's time.
                  sw.Restart();
                  foreach (var i in range(idx.Length))
                  {
                      var x = mnist.train.labels[i];
                      counts[idx[i]] += x;
                  }
                  sw.Stop();
                  print($"Assign a label to each centroid took {sw.ElapsedMilliseconds}ms");

                  // Assign the most frequent label to the centroid.
                  var labels_map_array = np.argmax(counts, 1);
                  var labels_map = tf.convert_to_tensor(labels_map_array);

                  // Evaluation ops
                  // Lookup: centroid_id -> label
                  var cluster_label = tf.nn.embedding_lookup(labels_map, cluster_idx);
                  // Compute accuracy: argmax over the one-hot labels recovers the digit.
                  var correct_prediction = tf.equal(cluster_label, tf.cast(tf.argmax(Y, 1), tf.int32));
                  var cast = tf.cast(correct_prediction, tf.float32);
                  var accuracy_op = tf.reduce_mean(cast);

                  // Test Model
                  var (test_x, test_y) = (mnist.test.images, mnist.test.labels);
                  result = sess.run(accuracy_op, new FeedItem(X, test_x), new FeedItem(Y, test_y));
                  print($"Test Accuracy: {result}");
              });

              return (float)result > 0.70;
          }

          /// <summary>
          /// Loads MNIST with one-hot labels, keeps the training images as the clustering input,
          /// and downloads the pre-exported K-means meta graph into the local graph directory.
          /// </summary>
          public void PrepareData()
          {
              mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size: validation_size, test_size: test_size);
              full_data_x = mnist.train.images;

              // download graph meta data
              string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/kmeans.meta";
              Web.Download(url, MetaGraphDir, MetaGraphFile);
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。