You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

NaiveBayesClassifier.cs 8.9 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using Tensorflow;
  5. using NumSharp;
  6. using System.Linq;
  7. using static Tensorflow.Python;
  8. using System.IO;
  9. using TensorFlowNET.Examples.Utility;
  10. namespace TensorFlowNET.Examples
  11. {
  12. /// <summary>
  13. /// https://github.com/nicolov/naive_bayes_tensorflow
  14. /// </summary>
  15. public class NaiveBayesClassifier : IExample
  16. {
  17. public bool Enabled { get; set; } = true;
  18. public string Name => "Naive Bayes Classifier";
  19. public bool IsImportingGraph { get; set; } = false;
  20. public NDArray X, y;
  21. public Normal dist { get; set; }
  /// <summary>
  /// Runs the example: prepares the data, fits the model with <c>fit</c> and
  /// evaluates <c>predict</c> on the sample points stored in nb/nb_example.npy.
  /// </summary>
  /// <returns>Always <c>true</c>.</returns>
  public bool Run()
  {
      PrepareData();
      fit(X, y);

      // Bounding box of the training points, widened by 0.5 on every side,
      // and a 30 x 30 regular grid spanning it.
      var featureMin = X.amin(0);
      var featureMax = X.amax(0);
      float left = featureMin.Data<float>(0) - 0.5f;
      float bottom = featureMin.Data<float>(1) - 0.5f;
      float right = featureMax.Data<float>(0) + 0.5f;
      float top = featureMax.Data<float>(1) + 0.5f;
      var (gridX, gridY) = np.meshgrid(np.linspace(left, right, 30), np.linspace(bottom, top, 30));

      with(tf.Session(), session =>
      {
          // The grid above is not classified here: the step that stacked and
          // transposed xx.ravel() / yy.ravel() into samples is disabled, and the
          // samples are read from the downloaded .npy file instead.
          var raw = np.Load<double[,]>(Path.Join("nb", "nb_example.npy"));
          var samples = np.array(raw).astype(np.float32);

          // Evaluate the probability tensor; the result is not used further.
          var probabilities = session.run(predict(samples));
      });

      return true;
  }
  /// <summary>
  /// Fits the classifier: groups the training samples by class label and stores
  /// in <c>dist</c> a normal distribution built from the per-class, per-feature
  /// mean and standard deviation of those samples.
  /// </summary>
  /// <param name="X">Training samples, read as <c>X[sample, feature]</c>.</param>
  /// <param name="y">Integer class label of each sample (read as <c>y[sample]</c>).</param>
  /// <exception cref="ArgumentException">
  /// The classes do not all have the same number of samples. The samples are packed
  /// into one dense (classes, samples, features) array before the moments are taken,
  /// so ragged classes cannot be represented without distorting the statistics.
  /// </exception>
  /// <remarks>
  /// Classes are laid out in ascending label order: row c of the fitted distribution
  /// belongs to the c-th smallest label. For labels 0..n-1 this is the label itself.
  /// </remarks>
  public void fit(NDArray X, NDArray y)
  {
      int nbFeatures = X.shape[1];

      // One bucket of feature rows per distinct label.
      var samplesByClass = new Dictionary<int, List<float[]>>();
      foreach (int label in y.unique<int>().Data<int>())
      {
          samplesByClass.Add(label, new List<float[]>());
      }

      // Separate training points by class, copying every feature column
      // (not just the first two) so the bucket rows match nbFeatures.
      for (int i = 0; i < y.size; i++)
      {
          int sampleClass = y[i];
          var features = new float[nbFeatures];
          for (int k = 0; k < nbFeatures; k++)
          {
              features[k] = X[i, k];
          }
          samplesByClass[sampleClass].Add(features);
      }

      // Map labels to dense row indices instead of using the label value as an
      // index, so labels such as {1, 2, 3} or {5, 10} stay inside the array.
      var orderedLabels = samplesByClass.Keys.OrderBy(key => key).ToList();

      int samplesPerClass = 0;
      foreach (var bucket in samplesByClass.Values)
      {
          samplesPerClass = Math.Max(samplesPerClass, bucket.Count);
      }

      foreach (int classLabel in orderedLabels)
      {
          int count = samplesByClass[classLabel].Count;
          if (count != samplesPerClass)
          {
              throw new ArgumentException(
                  $"All classes must have the same number of samples (class {classLabel} has {count}, expected {samplesPerClass}).",
                  nameof(y));
          }
      }

      // Shape : nb_classes * nb_samples * nb_features
      var points = new float[orderedLabels.Count, samplesPerClass, nbFeatures];
      for (int classIndex = 0; classIndex < orderedLabels.Count; classIndex++)
      {
          var rows = samplesByClass[orderedLabels[classIndex]];
          for (int i = 0; i < samplesPerClass; i++)
          {
              for (int k = 0; k < nbFeatures; k++)
              {
                  points[classIndex, i, k] = rows[i][k];
              }
          }
      }

      var points_by_class = np.array(points);

      // Estimate mean and variance for each class / feature by reducing over
      // the samples axis (axis 1); result shape : nb_classes * nb_features
      var cons = tf.constant(points_by_class);
      var tup = tf.nn.moments(cons, new int[] { 1 });
      var mean = tup.Item1;
      var variance = tup.Item2;

      // Univariate normal distribution per class / feature; the scale is the
      // standard deviation, hence the square root of the variance.
      this.dist = tf.distributions.Normal(mean, tf.sqrt(variance));
  }
  /// <summary>
  /// Builds the tensor of class probabilities for the given samples, using the
  /// normal distributions stored in <c>dist</c> and uniform class priors.
  /// </summary>
  /// <param name="X">
  /// Samples to classify. The data is converted to a float tensor, tiled once per
  /// class and reshaped to (-1, nb_classes, nb_features).
  /// </param>
  /// <returns>
  /// The tensor <c>exp(log P(c) + log P(x|c) - logsumexp over axis 1)</c>, i.e. the
  /// joint log-likelihood normalised across classes and exponentiated.
  /// </returns>
  /// <exception cref="ArgumentNullException">
  /// <c>dist</c> is null, i.e. no model has been stored yet.
  /// </exception>
  public Tensor predict(NDArray X)
  {
      if (dist == null)
      {
          // The single-string constructor treats its argument as the parameter
          // NAME; use the (paramName, message) overload so the text is the message.
          throw new ArgumentNullException(nameof(dist), "cannot find the model (normal distribution)!");
      }

      var scale = dist.scale();
      int nb_classes = (int)scale.shape[0];
      int nb_features = (int)scale.shape[1];

      // Conditional probabilities log P(x|c) with shape
      // (nb_samples, nb_classes)
      var samples = ops.convert_to_tensor(X, TF_DataType.TF_FLOAT);
      var multiples = ops.convert_to_tensor(new int[] { 1, nb_classes });
      Tensor tiled = tf.tile(samples, multiples);

      // Repeat each sample once per class so it lines up with the class axis of dist.
      var targetShape = ops.convert_to_tensor(new int[] { -1, nb_classes, nb_features });
      Tensor perClass = tf.reshape(tiled, targetShape);

      // Sum the per-feature log-probabilities (axis 2).
      var cond_probs = tf.reduce_sum(dist.log_prob(perClass), 2);

      // Uniform priors: every class gets 1 / nb_classes.
      float[] uniform = Enumerable.Repeat(1.0f / nb_classes, nb_classes).ToArray();
      var priors = np.log(np.array<float>(uniform));

      // Posterior log probability, log P(c) + log P(x|c)
      var joint_likelihood = tf.add(ops.convert_to_tensor(priors, TF_DataType.TF_FLOAT), cond_probs);

      // Normalize to get (log)-probabilities
      var norm_factor = tf.reduce_logsumexp(joint_likelihood, new int[] { 1 }, keepdims: true);
      var log_prob = joint_likelihood - norm_factor;

      // exp to get the actual probabilities
      return tf.exp(log_prob);
  }
  /// <summary>
  /// Fills <c>X</c> and <c>y</c> with the hard-coded training set and requests the
  /// download of the sample file nb_example.npy.
  /// </summary>
  public void PrepareData()
  {
      // Training samples: one { feature0, feature1 } pair per row.
      // NOTE(review): the values look like the first two Iris features; if so,
      // the row at index 49 ({7.0, 3.2}) appears to belong to the second group
      // while y labels it 0 - verify the data/label alignment against the
      // intended dataset.
      X = np.array(new float[,] {
          {5.1f, 3.5f}, {4.9f, 3.0f}, {4.7f, 3.2f}, {4.6f, 3.1f}, {5.0f, 3.6f}, {5.4f, 3.9f},
          {4.6f, 3.4f}, {5.0f, 3.4f}, {4.4f, 2.9f}, {4.9f, 3.1f}, {5.4f, 3.7f}, {4.8f, 3.4f},
          {4.8f, 3.0f}, {4.3f, 3.0f}, {5.8f, 4.0f}, {5.7f, 4.4f}, {5.4f, 3.9f}, {5.1f, 3.5f},
          {5.7f, 3.8f}, {5.1f, 3.8f}, {5.4f, 3.4f}, {5.1f, 3.7f}, {5.1f, 3.3f}, {4.8f, 3.4f},
          {5.0f, 3.0f}, {5.0f, 3.4f}, {5.2f, 3.5f}, {5.2f, 3.4f}, {4.7f, 3.2f}, {4.8f, 3.1f},
          {5.4f, 3.4f}, {5.2f, 4.1f}, {5.5f, 4.2f}, {4.9f, 3.1f}, {5.0f, 3.2f}, {5.5f, 3.5f},
          {4.9f, 3.6f}, {4.4f, 3.0f}, {5.1f, 3.4f}, {5.0f, 3.5f}, {4.5f, 2.3f}, {4.4f, 3.2f},
          {5.0f, 3.5f}, {5.1f, 3.8f}, {4.8f, 3.0f}, {5.1f, 3.8f}, {4.6f, 3.2f}, {5.3f, 3.7f},
          {5.0f, 3.3f}, {7.0f, 3.2f}, {6.4f, 3.2f}, {6.9f, 3.1f}, {5.5f, 2.3f}, {6.5f, 2.8f},
          {5.7f, 2.8f}, {6.3f, 3.3f}, {4.9f, 2.4f}, {6.6f, 2.9f}, {5.2f, 2.7f}, {5.0f, 2.0f},
          {5.9f, 3.0f}, {6.0f, 2.2f}, {6.1f, 2.9f}, {5.6f, 2.9f}, {6.7f, 3.1f}, {5.6f, 3.0f},
          {5.8f, 2.7f}, {6.2f, 2.2f}, {5.6f, 2.5f}, {5.9f, 3.0f}, {6.1f, 2.8f}, {6.3f, 2.5f},
          {6.1f, 2.8f}, {6.4f, 2.9f}, {6.6f, 3.0f}, {6.8f, 2.8f}, {6.7f, 3.0f}, {6.0f, 2.9f},
          {5.7f, 2.6f}, {5.5f, 2.4f}, {5.5f, 2.4f}, {5.8f, 2.7f}, {6.0f, 2.7f}, {5.4f, 3.0f},
          {6.0f, 3.4f}, {6.7f, 3.1f}, {6.3f, 2.3f}, {5.6f, 3.0f}, {5.5f, 2.5f}, {5.5f, 2.6f},
          {6.1f, 3.0f}, {5.8f, 2.6f}, {5.0f, 2.3f}, {5.6f, 2.7f}, {5.7f, 3.0f}, {5.7f, 2.9f},
          {6.2f, 2.9f}, {5.1f, 2.5f}, {5.7f, 2.8f}, {6.3f, 3.3f}, {5.8f, 2.7f}, {7.1f, 3.0f},
          {6.3f, 2.9f}, {6.5f, 3.0f}, {7.6f, 3.0f}, {4.9f, 2.5f}, {7.3f, 2.9f}, {6.7f, 2.5f},
          {7.2f, 3.6f}, {6.5f, 3.2f}, {6.4f, 2.7f}, {6.8f, 3.0f}, {5.7f, 2.5f}, {5.8f, 2.8f},
          {6.4f, 3.2f}, {6.5f, 3.0f}, {7.7f, 3.8f}, {7.7f, 2.6f}, {6.0f, 2.2f}, {6.9f, 3.2f},
          {5.6f, 2.8f}, {7.7f, 2.8f}, {6.3f, 2.7f}, {6.7f, 3.3f}, {7.2f, 3.2f}, {6.2f, 2.8f},
          {6.1f, 3.0f}, {6.4f, 2.8f}, {7.2f, 3.0f}, {7.4f, 2.8f}, {7.9f, 3.8f}, {6.4f, 2.8f},
          {6.3f, 2.8f}, {6.1f, 2.6f}, {7.7f, 3.0f}, {6.3f, 3.4f}, {6.4f, 3.1f}, {6.0f, 3.0f},
          {6.9f, 3.1f}, {6.7f, 3.1f}, {6.9f, 3.1f}, {5.8f, 2.7f}, {6.8f, 3.2f}, {6.7f, 3.3f},
          {6.7f, 3.0f}, {6.3f, 2.5f}, {6.5f, 3.0f}, {6.2f, 3.4f}, {5.9f, 3.0f}, {5.8f, 3.0f}});

      // Class label (0, 1 or 2) for each row of X, in the same order.
      y = np.array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
          2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
          2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2);

      // Sample points that Run() later loads from nb/nb_example.npy
      // (presumably where Web.Download stores the file - confirm in Web.Download).
      const string sampleUrl = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/nb_example.npy";
      Web.Download(sampleUrl, "nb", "nb_example.npy");
  }
  /// <summary>Not implemented by this example.</summary>
  /// <exception cref="NotImplementedException">Always thrown.</exception>
  public Graph ImportGraph() => throw new NotImplementedException();
  /// <summary>Not implemented by this example.</summary>
  /// <exception cref="NotImplementedException">Always thrown.</exception>
  public Graph BuildGraph() => throw new NotImplementedException();
  /// <summary>Not implemented by this example.</summary>
  /// <returns>Never returns normally.</returns>
  /// <exception cref="NotImplementedException">Always thrown.</exception>
  public bool Train() => throw new NotImplementedException();
  /// <summary>Not implemented by this example.</summary>
  /// <returns>Never returns normally.</returns>
  /// <exception cref="NotImplementedException">Always thrown.</exception>
  public bool Predict() => throw new NotImplementedException();
  179. }
  180. }