You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

NaiveBayesClassifier.cs 9.8 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. /*****************************************************************************
  2. Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ******************************************************************************/
  13. using System;
  14. using System.Collections.Generic;
  15. using Tensorflow;
  16. using NumSharp;
  17. using static Tensorflow.Binding;
  18. using System.IO;
  19. using TensorFlowNET.Examples.Utility;
  namespace TensorFlowNET.Examples
  {
      /// <summary>
      /// Gaussian naive Bayes classifier built with TensorFlow.NET.
      /// Port of https://github.com/nicolov/naive_bayes_tensorflow
      /// </summary>
      /// <remarks>
      /// <c>fit</c> estimates a per-class, per-feature mean and population variance and stores them as a
      /// normal distribution in <see cref="dist"/>. <c>predict</c> combines the per-feature log-likelihoods
      /// with uniform class priors and returns normalized class probabilities.
      /// </remarks>
      public class NaiveBayesClassifier : IExample
      {
          public bool Enabled { get; set; } = true;
          public string Name => "Naive Bayes Classifier";
          public bool IsImportingGraph { get; set; } = false;

          /// <summary>Training features (one row per sample) and integer class labels, filled by <see cref="PrepareData"/>.</summary>
          public NDArray X, y;

          /// <summary>Fitted class-conditional normal distribution; <c>null</c> until <c>fit</c> has run.</summary>
          public Normal dist { get; set; }

          /// <summary>
          /// Loads the training data, fits the model and evaluates it on a set of sample points.
          /// </summary>
          /// <returns>Always <c>true</c>; failures surface as exceptions.</returns>
          public bool Run()
          {
              PrepareData();
              fit(X, y);

              // Create a regular 30x30 grid covering the training data with a 0.5 margin on every side.
              var mins = X.amin(0).Data<float>();
              var maxs = X.amax(0).Data<float>();
              float x_min = mins[0] - 0.5f;
              float y_min = mins[1] - 0.5f;
              float x_max = maxs[0] + 0.5f; // x bound comes from feature 0 (previously read feature 1 by mistake)
              float y_max = maxs[1] + 0.5f;
              var (xx, yy) = np.meshgrid(np.linspace(x_min, x_max, 30), np.linspace(y_min, y_max, 30));

              using (var sess = tf.Session())
              {
                  // The points actually classified are read from the nb_example.npy file downloaded by
                  // PrepareData rather than built from xx/yy.
                  // NOTE(review): presumably that file holds the flattened grid -- confirm.
                  var array = np.Load<double[,]>(Path.Join("nb", "nb_example.npy"));
                  var samples = np.array(array).astype(np.float32);
                  var Z = sess.run(predict(samples));
              }

              return true;
          }

          /// <summary>
          /// Estimates the mean and variance of every feature within every class and stores the
          /// resulting (nb_classes x nb_features) normal distribution in <see cref="dist"/>.
          /// </summary>
          /// <param name="X">Training features, shape (nb_samples, nb_features).</param>
          /// <param name="y">Integer class labels, one per row of <paramref name="X"/>.</param>
          /// <exception cref="ArgumentNullException"><paramref name="X"/> or <paramref name="y"/> is null.</exception>
          /// <exception cref="ArgumentException">
          /// <paramref name="X"/> is not 2-D or has no rows, or the label count differs from the row count.
          /// </exception>
          public void fit(NDArray X, NDArray y)
          {
              // `is null` avoids NDArray's overloaded (element-wise) == operator.
              if (X is null)
              {
                  throw new ArgumentNullException(nameof(X));
              }
              if (y is null)
              {
                  throw new ArgumentNullException(nameof(y));
              }
              if (X.shape.Length != 2)
              {
                  throw new ArgumentException("X must be a 2-D array of shape (nb_samples, nb_features).", nameof(X));
              }

              int nb_samples = X.shape[0];
              int nb_features = X.shape[1];
              if (nb_samples == 0)
              {
                  throw new ArgumentException("At least one training sample is required.", nameof(X));
              }
              if (y.size != nb_samples)
              {
                  throw new ArgumentException($"y has {y.size} labels but X has {nb_samples} rows.", nameof(y));
              }

              // Separate training points by class. SortedDictionary orders classes by ascending label,
              // so class index c of the fitted distribution is the c-th smallest label
              // (for labels 0..k-1 this is exactly the label itself).
              var rows_by_class = new SortedDictionary<int, List<float[]>>();
              for (int i = 0; i < nb_samples; i++)
              {
                  int label = (int)y[i];
                  if (!rows_by_class.TryGetValue(label, out var class_rows))
                  {
                      class_rows = new List<float[]>();
                      rows_by_class.Add(label, class_rows);
                  }

                  var row = new float[nb_features];
                  for (int k = 0; k < nb_features; k++)
                  {
                      row[k] = (float)X[i, k];
                  }
                  class_rows.Add(row);
              }

              // Estimate mean and variance for each class / feature; shape: nb_classes x nb_features.
              // Done per class so that classes with different sample counts are supported
              // (a dense nb_classes x nb_samples x nb_features tensor would require equal counts).
              int nb_classes = rows_by_class.Count;
              var means = new float[nb_classes, nb_features];
              var variances = new float[nb_classes, nb_features];
              int c = 0;
              foreach (List<float[]> members in rows_by_class.Values)
              {
                  for (int k = 0; k < nb_features; k++)
                  {
                      double sum = 0;
                      foreach (float[] r in members)
                      {
                          sum += r[k];
                      }
                      double mu = sum / members.Count;

                      double sum_sq = 0;
                      foreach (float[] r in members)
                      {
                          double d = r[k] - mu;
                          sum_sq += d * d;
                      }

                      means[c, k] = (float)mu;
                      // Population variance (divide by N), the same definition tf.nn.moments uses.
                      variances[c, k] = (float)(sum_sq / members.Count);
                  }
                  c++;
              }

              // Create an nb_classes x nb_features univariate normal distribution with the known mean and
              // variance. NOTE(review): a feature that is constant within a class gets variance 0 (scale 0),
              // which makes its log-probabilities non-finite; consider adding a small epsilon.
              var loc = tf.constant(np.array(means));
              var variance = tf.constant(np.array(variances));
              this.dist = tf.distributions.Normal(loc, tf.sqrt(variance));
          }

          /// <summary>
          /// Builds the graph that computes class membership probabilities for each sample.
          /// </summary>
          /// <param name="X">Samples to classify, shape (nb_samples, nb_features).</param>
          /// <returns>A tensor of shape (nb_samples, nb_classes) whose rows are probabilities summing to 1.</returns>
          /// <exception cref="ArgumentNullException">The model has not been fitted, or <paramref name="X"/> is null.</exception>
          public Tensor predict(NDArray X)
          {
              if (dist is null)
              {
                  // Same exception type as before, but the text is now the message instead of the parameter name.
                  throw new ArgumentNullException(nameof(dist), "Cannot find the model (normal distribution); call fit() first.");
              }
              if (X is null)
              {
                  throw new ArgumentNullException(nameof(X));
              }

              var scale = dist.scale();
              int nb_classes = (int)scale.shape[0];
              int nb_features = (int)scale.shape[1];

              // Conditional probabilities log P(x|c) with shape (nb_samples, nb_classes):
              // repeat every sample once per class and view it as (nb_samples, nb_classes, nb_features)
              // so it lines up with the (nb_classes, nb_features) distribution, then sum over features.
              var sample_tensor = ops.convert_to_tensor(X, TF_DataType.TF_FLOAT);
              Tensor tiled = tf.tile(sample_tensor, ops.convert_to_tensor(new int[] { 1, nb_classes }));
              Tensor per_class = tf.reshape(tiled, ops.convert_to_tensor(new int[] { -1, nb_classes, nb_features }));
              var cond_probs = tf.reduce_sum(dist.log_prob(per_class), 2);

              // Uniform priors: P(c) = 1 / nb_classes for every class.
              float[] uniform = new float[nb_classes];
              for (int i = 0; i < uniform.Length; i++)
              {
                  uniform[i] = 1.0f / nb_classes;
              }
              var priors = np.log(np.array<float>(uniform));

              // Posterior log probability, log P(c) + log P(x|c).
              var joint_likelihood = tf.add(ops.convert_to_tensor(priors, TF_DataType.TF_FLOAT), cond_probs);
              // Normalize to get (log-)probabilities.
              var norm_factor = tf.reduce_logsumexp(joint_likelihood, new int[] { 1 }, keepdims: true);
              var log_prob = joint_likelihood - norm_factor;
              // Exponentiate to get the actual probabilities.
              return tf.exp(log_prob);
          }

          /// <summary>
          /// Fills <see cref="X"/> (two features per sample) and <see cref="y"/> (labels 0, 1 or 2) with the
          /// built-in training set and downloads the nb_example.npy points that <c>Run</c> classifies.
          /// </summary>
          public void PrepareData()
          {
              #region Training data
              X = np.array(new float[,] {
                  {5.1f, 3.5f}, {4.9f, 3.0f}, {4.7f, 3.2f}, {4.6f, 3.1f}, {5.0f, 3.6f}, {5.4f, 3.9f},
                  {4.6f, 3.4f}, {5.0f, 3.4f}, {4.4f, 2.9f}, {4.9f, 3.1f}, {5.4f, 3.7f}, {4.8f, 3.4f},
                  {4.8f, 3.0f}, {4.3f, 3.0f}, {5.8f, 4.0f}, {5.7f, 4.4f}, {5.4f, 3.9f}, {5.1f, 3.5f},
                  {5.7f, 3.8f}, {5.1f, 3.8f}, {5.4f, 3.4f}, {5.1f, 3.7f}, {5.1f, 3.3f}, {4.8f, 3.4f},
                  {5.0f, 3.0f}, {5.0f, 3.4f}, {5.2f, 3.5f}, {5.2f, 3.4f}, {4.7f, 3.2f}, {4.8f, 3.1f},
                  {5.4f, 3.4f}, {5.2f, 4.1f}, {5.5f, 4.2f}, {4.9f, 3.1f}, {5.0f, 3.2f}, {5.5f, 3.5f},
                  {4.9f, 3.6f}, {4.4f, 3.0f}, {5.1f, 3.4f}, {5.0f, 3.5f}, {4.5f, 2.3f}, {4.4f, 3.2f},
                  {5.0f, 3.5f}, {5.1f, 3.8f}, {4.8f, 3.0f}, {5.1f, 3.8f}, {4.6f, 3.2f}, {5.3f, 3.7f},
                  {5.0f, 3.3f}, {7.0f, 3.2f}, {6.4f, 3.2f}, {6.9f, 3.1f}, {5.5f, 2.3f}, {6.5f, 2.8f},
                  {5.7f, 2.8f}, {6.3f, 3.3f}, {4.9f, 2.4f}, {6.6f, 2.9f}, {5.2f, 2.7f}, {5.0f, 2.0f},
                  {5.9f, 3.0f}, {6.0f, 2.2f}, {6.1f, 2.9f}, {5.6f, 2.9f}, {6.7f, 3.1f}, {5.6f, 3.0f},
                  {5.8f, 2.7f}, {6.2f, 2.2f}, {5.6f, 2.5f}, {5.9f, 3.0f}, {6.1f, 2.8f}, {6.3f, 2.5f},
                  {6.1f, 2.8f}, {6.4f, 2.9f}, {6.6f, 3.0f}, {6.8f, 2.8f}, {6.7f, 3.0f}, {6.0f, 2.9f},
                  {5.7f, 2.6f}, {5.5f, 2.4f}, {5.5f, 2.4f}, {5.8f, 2.7f}, {6.0f, 2.7f}, {5.4f, 3.0f},
                  {6.0f, 3.4f}, {6.7f, 3.1f}, {6.3f, 2.3f}, {5.6f, 3.0f}, {5.5f, 2.5f}, {5.5f, 2.6f},
                  {6.1f, 3.0f}, {5.8f, 2.6f}, {5.0f, 2.3f}, {5.6f, 2.7f}, {5.7f, 3.0f}, {5.7f, 2.9f},
                  {6.2f, 2.9f}, {5.1f, 2.5f}, {5.7f, 2.8f}, {6.3f, 3.3f}, {5.8f, 2.7f}, {7.1f, 3.0f},
                  {6.3f, 2.9f}, {6.5f, 3.0f}, {7.6f, 3.0f}, {4.9f, 2.5f}, {7.3f, 2.9f}, {6.7f, 2.5f},
                  {7.2f, 3.6f}, {6.5f, 3.2f}, {6.4f, 2.7f}, {6.8f, 3.0f}, {5.7f, 2.5f}, {5.8f, 2.8f},
                  {6.4f, 3.2f}, {6.5f, 3.0f}, {7.7f, 3.8f}, {7.7f, 2.6f}, {6.0f, 2.2f}, {6.9f, 3.2f},
                  {5.6f, 2.8f}, {7.7f, 2.8f}, {6.3f, 2.7f}, {6.7f, 3.3f}, {7.2f, 3.2f}, {6.2f, 2.8f},
                  {6.1f, 3.0f}, {6.4f, 2.8f}, {7.2f, 3.0f}, {7.4f, 2.8f}, {7.9f, 3.8f}, {6.4f, 2.8f},
                  {6.3f, 2.8f}, {6.1f, 2.6f}, {7.7f, 3.0f}, {6.3f, 3.4f}, {6.4f, 3.1f}, {6.0f, 3.0f},
                  {6.9f, 3.1f}, {6.7f, 3.1f}, {6.9f, 3.1f}, {5.8f, 2.7f}, {6.8f, 3.2f}, {6.7f, 3.3f},
                  {6.7f, 3.0f}, {6.3f, 2.5f}, {6.5f, 3.0f}, {6.2f, 3.4f}, {5.9f, 3.0f}, {5.8f, 3.0f}});
              y = np.array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                  0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
                  2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
                  2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2);
              #endregion

              // Points that Run() feeds to predict(); stored as nb/nb_example.npy.
              string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/nb_example.npy";
              Web.Download(url, "nb", "nb_example.npy");
          }

          public Graph ImportGraph()
          {
              throw new NotImplementedException();
          }

          public Graph BuildGraph()
          {
              throw new NotImplementedException();
          }

          public void Train(Session sess)
          {
              throw new NotImplementedException();
          }

          public void Predict(Session sess)
          {
              throw new NotImplementedException();
          }

          public void Test(Session sess)
          {
              throw new NotImplementedException();
          }
      }
  }