You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

KMeans.cs 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. /*****************************************************************************
  2. Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ******************************************************************************/
  13. using System;
  14. using System.Collections.Generic;
  15. using System.Text;
  16. using static Tensorflow.Python;
  17. namespace Tensorflow.Clustering
  18. {
  /// <summary>
  /// Creates the graph for k-means clustering.
  /// </summary>
  /// <remarks>
  /// This port is incomplete: <see cref="training_graph"/> creates the cluster variables and the
  /// cluster-initialization op and then throws <see cref="NotImplementedException"/>.
  /// </remarks>
  public class KMeans
  {
    /// <summary>Name given to the variable that holds the cluster centers.</summary>
    public const string CLUSTERS_VAR_NAME = "clusters";

    /// <summary>Identifier of the squared Euclidean distance metric (the default metric).</summary>
    public const string SQUARED_EUCLIDEAN_DISTANCE = "squared_euclidean";

    /// <summary>Identifier of the cosine distance metric.</summary>
    public const string COSINE_DISTANCE = "cosine";

    /// <summary>Identifier of the random cluster initialization (the default initialization).</summary>
    public const string RANDOM_INIT = "random";

    /// <summary>Identifier of the k-means++ cluster initialization.</summary>
    public const string KMEANS_PLUS_PLUS_INIT = "kmeans_plus_plus";

    /// <summary>Identifier of the KMC2 cluster initialization.</summary>
    public const string KMC2_INIT = "kmc2";

    // Configuration captured once by the constructor and never reassigned afterwards.
    private readonly Tensor[] _inputs;
    private readonly int _num_clusters;
    private readonly string _initial_clusters;
    private readonly string _distance_metric;
    private readonly bool _use_mini_batch;
    private readonly int _mini_batch_steps_per_iteration;
    private readonly int _random_seed;
    private readonly int _kmeans_plus_plus_num_retries;
    private readonly int _kmc2_chain_length;

    /// <summary>
    /// Stores the clustering configuration. No graph operations are created here;
    /// they are created by <see cref="training_graph"/>.
    /// </summary>
    /// <param name="inputs">Input tensor; it is wrapped in a single-element array.</param>
    /// <param name="num_clusters">Number of clusters; converted to a tensor when the graph is built.</param>
    /// <param name="initial_clusters">Initialization identifier forwarded to the cluster-initialization op factory.</param>
    /// <param name="distance_metric">Distance metric identifier forwarded to the cluster-initialization op factory.</param>
    /// <param name="use_mini_batch">When true, a per-cluster counts variable is created in addition to the cluster centers.</param>
    /// <param name="mini_batch_steps_per_iteration">
    /// Values greater than 1 combined with <paramref name="use_mini_batch"/> are not implemented yet.
    /// </param>
    /// <param name="random_seed">Seed forwarded to the cluster-initialization op factory.</param>
    /// <param name="kmeans_plus_plus_num_retries">Retry count forwarded to the cluster-initialization op factory.</param>
    /// <param name="kmc2_chain_length">Chain length forwarded to the cluster-initialization op factory.</param>
    public KMeans(Tensor inputs,
      int num_clusters,
      string initial_clusters = RANDOM_INIT,
      string distance_metric = SQUARED_EUCLIDEAN_DISTANCE,
      bool use_mini_batch = false,
      int mini_batch_steps_per_iteration = 1,
      int random_seed = 0,
      int kmeans_plus_plus_num_retries = 2,
      int kmc2_chain_length = 200)
    {
      _inputs = new Tensor[] { inputs };
      _num_clusters = num_clusters;
      _initial_clusters = initial_clusters;
      _distance_metric = distance_metric;
      _use_mini_batch = use_mini_batch;
      _mini_batch_steps_per_iteration = mini_batch_steps_per_iteration;
      _random_seed = random_seed;
      _kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries;
      _kmc2_chain_length = kmc2_chain_length;
    }

    /// <summary>
    /// Starts building the training graph: creates the cluster variables and the
    /// cluster-initialization op. The remainder of the graph is not implemented.
    /// </summary>
    /// <returns>Never returns normally at present.</returns>
    /// <exception cref="NotImplementedException">
    /// Always thrown, either by the variable creation (mini-batch with more than one step per
    /// iteration) or at the end of this method.
    /// </exception>
    public object training_graph()
    {
      var num_clusters = ops.convert_to_tensor(_num_clusters);

      // Layout of the returned array:
      //   [0] cluster centers, [1] "initialized" flag, [2] per-cluster counts (null unless mini-batch),
      //   [3] updated cluster centers, [4] update_in_steps (currently always null).
      var vars = _create_variables(num_clusters);
      var cluster_centers_var = vars[0];
      var cluster_centers_initialized = vars[1];
      var cluster_centers_updated = vars[3];

      // The op is still created so the graph contents built before the throw stay the same;
      // its result is not consumed until the rest of the training graph is implemented.
      new _InitializeClustersOpFactory(_inputs, num_clusters, _initial_clusters, _distance_metric,
        _random_seed, _kmeans_plus_plus_num_retries,
        _kmc2_chain_length, cluster_centers_var, cluster_centers_updated,
        cluster_centers_initialized).op();

      throw new NotImplementedException("KMeans training_graph");
    }

    /// <summary>
    /// Creates the variables used by the clustering graph.
    /// </summary>
    /// <param name="num_clusters">Tensor holding the number of clusters; used as the shape of the counts variable.</param>
    /// <returns>
    /// Five entries in this order: cluster centers, initialized flag, per-cluster counts
    /// (null unless mini-batch is enabled), updated cluster centers (the same variable as the
    /// cluster centers), and update_in_steps (always null).
    /// </returns>
    /// <exception cref="NotImplementedException">
    /// Thrown when mini-batch is enabled with more than one step per iteration.
    /// </exception>
    private RefVariable[] _create_variables(Tensor num_clusters)
    {
      var init_value = constant_op.constant(new float[0], dtype: TF_DataType.TF_FLOAT);
      var cluster_centers = tf.Variable(init_value, name: CLUSTERS_VAR_NAME, validate_shape: false);
      var cluster_centers_initialized = tf.Variable(false, dtype: TF_DataType.TF_BOOL, name: "initialized");

      if (_use_mini_batch && _mini_batch_steps_per_iteration > 1)
      {
        throw new NotImplementedException("KMeans._create_variables");
      }

      // Without multi-step mini-batch the centers are updated in place, so the "updated"
      // slot aliases the centers variable and no step counter is needed.
      var cluster_centers_updated = cluster_centers;
      RefVariable update_in_steps = null;

      // Build the ones tensor only when it is actually consumed, so that the default
      // (full-batch) mode does not add an unused op to the graph.
      var cluster_counts = _use_mini_batch
        ? tf.Variable(array_ops.ones(new Tensor[] { num_clusters }, dtype: TF_DataType.TF_INT64))
        : null;

      return new RefVariable[]
      {
        cluster_centers,
        cluster_centers_initialized,
        cluster_counts,
        cluster_centers_updated,
        update_in_steps
      };
    }
  }
  100. }