You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

NaiveBayesClassifier.cs 11 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using Tensorflow;
  5. using NumSharp;
  6. using System.Linq;
  namespace TensorFlowNET.Examples
  {
      /// <summary>
      /// Gaussian naive Bayes classifier trained on the first two features of the Iris data set
      /// (sepal length, sepal width).
      /// Port of https://github.com/nicolov/naive_bayes_tensorflow
      /// </summary>
      public class NaiveBayesClassifier : Python, IExample
      {
          public int Priority => 6;
          public bool Enabled { get; set; } = true;
          public string Name => "Naive Bayes Classifier";
          public bool ImportGraph { get; set; } = false;

          /// <summary>
          /// The fitted model: one univariate normal distribution per (class, feature) pair.
          /// Null until <see cref="fit"/> has been called.
          /// </summary>
          public Normal dist { get; set; }

          /// <summary>
          /// Fits the model on the embedded Iris data, then classifies a 30x30 grid of points
          /// spanning the data range.
          /// </summary>
          /// <returns>Always <c>true</c>.</returns>
          public bool Run()
          {
              var X = LoadIrisFeatures();
              var y = LoadIrisLabels();
              fit(X, y);

              // Create a regular grid (data range padded by 0.5 on each side) and classify each point
              var mins = X.amin(0);
              var maxs = X.amax(0);
              float x_min = mins.Data<float>(0) - 0.5f;
              float y_min = mins.Data<float>(1) - 0.5f;
              float x_max = maxs.Data<float>(0) + 0.5f;
              float y_max = maxs.Data<float>(1) + 0.5f;
              var (xx, yy) = np.meshgrid(np.linspace(x_min, x_max, 30), np.linspace(y_min, y_max, 30));

              // The session is disposed as soon as the grid has been scored.
              using (var s = tf.Session())
              {
                  // NOTE(review): the grid is only scored when linspace/meshgrid yield float32 data;
                  // confirm the dtype NumSharp returns here, otherwise this branch never runs.
                  if (xx.dtype == typeof(float))
                  {
                      var samples = np.hstack<float>(xx.ravel().reshape(xx.size, 1), yy.ravel().reshape(yy.size, 1));
                      // Z: one row of class probabilities per grid point; not used further in this example.
                      var Z = s.run(predict(samples));
                  }
              }
              return true;
          }

          /// <summary>
          /// Estimates, for every class and feature, the mean and (population) variance of the
          /// training data and stores the resulting normal distributions in <see cref="dist"/>.
          /// </summary>
          /// <param name="X">Training features; row i (indexed as X[i, k]) is the feature vector of sample i.</param>
          /// <param name="y">Integer class label of each row of <paramref name="X"/>.</param>
          /// <exception cref="ArgumentException">The number of labels differs from the number of rows of X.</exception>
          public void fit(NDArray X, NDArray y)
          {
              int nb_samples = y.size;
              int nb_features = X.shape[1];
              if (X.shape[0] != nb_samples)
              {
                  throw new ArgumentException($"X has {X.shape[0]} rows but y has {nb_samples} labels.", nameof(y));
              }

              // Map each distinct label to a row of the parameter matrices, in ascending label order.
              // Labels 0..n-1 land on rows 0..n-1; non-contiguous labels no longer index out of range.
              var labels = new List<int>();
              foreach (int label in y.unique<int>().Data<int>())
              {
                  labels.Add(label);
              }
              labels.Sort();
              int nb_classes = labels.Count;
              var row_of_label = new Dictionary<int, int>();
              for (int c = 0; c < nb_classes; c++)
              {
                  row_of_label[labels[c]] = c;
              }

              // Separate training points by class (every feature column, not just the first two).
              var points_by_class = new List<float[]>[nb_classes];
              for (int c = 0; c < nb_classes; c++)
              {
                  points_by_class[c] = new List<float[]>();
              }
              for (int i = 0; i < nb_samples; i++)
              {
                  int label = y[i];
                  var point = new float[nb_features];
                  for (int k = 0; k < nb_features; k++)
                  {
                      point[k] = X[i, k];
                  }
                  points_by_class[row_of_label[label]].Add(point);
              }

              // Estimate mean and variance for each class / feature.
              // Done per class so classes of different sizes are never zero-padded into a common shape.
              // shape : nb_classes * nb_features
              var mean = new float[nb_classes, nb_features];
              var variance = new float[nb_classes, nb_features];
              for (int c = 0; c < nb_classes; c++)
              {
                  var points = points_by_class[c];
                  for (int k = 0; k < nb_features; k++)
                  {
                      double m = points.Average(p => (double)p[k]);
                      double v = points.Average(p => (p[k] - m) * (p[k] - m));
                      mean[c, k] = (float)m;
                      variance[c, k] = (float)v;
                  }
              }

              // Create an nb_classes x nb_features univariate normal distribution with the
              // known mean and standard deviation (sqrt of the variance).
              this.dist = tf.distributions.Normal(tf.constant(np.array(mean)), tf.sqrt(tf.constant(np.array(variance))));
          }

          /// <summary>
          /// Computes class membership probabilities P(c | x) for each sample under the fitted model,
          /// assuming uniform class priors.
          /// </summary>
          /// <param name="X">Samples to classify; each row must have as many columns as the model has features.</param>
          /// <returns>A tensor with one row per sample and one column per class.</returns>
          /// <exception cref="ArgumentNullException"><see cref="fit"/> has not been called yet.</exception>
          public Tensor predict(NDArray X)
          {
              if (dist == null)
              {
                  throw new ArgumentNullException(nameof(dist), "Cannot find the model (normal distribution); call fit() first.");
              }

              var scale = dist.scale();
              int nb_classes = (int)scale.shape[0];
              int nb_features = (int)scale.shape[1];

              // Conditional probabilities log P(x|c) with shape (nb_samples, nb_classes):
              // repeat each sample once per class, view it as (nb_samples, nb_classes, nb_features),
              // score each feature under its class Gaussian and sum the per-feature log densities.
              var samples = ops.convert_to_tensor(X, TF_DataType.TF_FLOAT);
              Tensor tiled = tf.tile(samples, ops.convert_to_tensor(new int[] { 1, nb_classes }));
              Tensor per_class = tf.reshape(tiled, ops.convert_to_tensor(new int[] { -1, nb_classes, nb_features }));
              var cond_probs = tf.reduce_sum(dist.log_prob(per_class), 2);

              // uniform priors: log P(c) = log(1 / nb_classes) for every class
              float[] uniform = Enumerable.Repeat(1.0f / nb_classes, nb_classes).ToArray();
              var priors = np.log(np.array<float>(uniform));

              // posterior log probability, log P(c) + log P(x|c)
              var joint_likelihood = tf.add(ops.convert_to_tensor(priors, TF_DataType.TF_FLOAT), cond_probs);

              // normalize to get (log)-probabilities
              var norm_factor = tf.reduce_logsumexp(joint_likelihood, new int[] { 1 }, keepdims: true);
              var log_prob = joint_likelihood - norm_factor;

              // exp to get the actual probabilities
              return tf.exp(log_prob);
          }

          public void PrepareData()
          {
          }

          /// <summary>
          /// Sepal length and sepal width of the 150 Iris samples, 50 per species in the order
          /// setosa, versicolor, virginica. Row i must stay aligned with label i of
          /// <see cref="LoadIrisLabels"/>; the setosa sample (4.6, 3.6) is included here.
          /// </summary>
          private static NDArray LoadIrisFeatures()
          {
              return np.array(new float[][]
              {
                  // setosa (class 0)
                  new[] { 5.1f, 3.5f }, new[] { 4.9f, 3.0f }, new[] { 4.7f, 3.2f }, new[] { 4.6f, 3.1f }, new[] { 5.0f, 3.6f },
                  new[] { 5.4f, 3.9f }, new[] { 4.6f, 3.4f }, new[] { 5.0f, 3.4f }, new[] { 4.4f, 2.9f }, new[] { 4.9f, 3.1f },
                  new[] { 5.4f, 3.7f }, new[] { 4.8f, 3.4f }, new[] { 4.8f, 3.0f }, new[] { 4.3f, 3.0f }, new[] { 5.8f, 4.0f },
                  new[] { 5.7f, 4.4f }, new[] { 5.4f, 3.9f }, new[] { 5.1f, 3.5f }, new[] { 5.7f, 3.8f }, new[] { 5.1f, 3.8f },
                  new[] { 5.4f, 3.4f }, new[] { 5.1f, 3.7f }, new[] { 4.6f, 3.6f }, new[] { 5.1f, 3.3f }, new[] { 4.8f, 3.4f },
                  new[] { 5.0f, 3.0f }, new[] { 5.0f, 3.4f }, new[] { 5.2f, 3.5f }, new[] { 5.2f, 3.4f }, new[] { 4.7f, 3.2f },
                  new[] { 4.8f, 3.1f }, new[] { 5.4f, 3.4f }, new[] { 5.2f, 4.1f }, new[] { 5.5f, 4.2f }, new[] { 4.9f, 3.1f },
                  new[] { 5.0f, 3.2f }, new[] { 5.5f, 3.5f }, new[] { 4.9f, 3.6f }, new[] { 4.4f, 3.0f }, new[] { 5.1f, 3.4f },
                  new[] { 5.0f, 3.5f }, new[] { 4.5f, 2.3f }, new[] { 4.4f, 3.2f }, new[] { 5.0f, 3.5f }, new[] { 5.1f, 3.8f },
                  new[] { 4.8f, 3.0f }, new[] { 5.1f, 3.8f }, new[] { 4.6f, 3.2f }, new[] { 5.3f, 3.7f }, new[] { 5.0f, 3.3f },
                  // versicolor (class 1)
                  new[] { 7.0f, 3.2f }, new[] { 6.4f, 3.2f }, new[] { 6.9f, 3.1f }, new[] { 5.5f, 2.3f }, new[] { 6.5f, 2.8f },
                  new[] { 5.7f, 2.8f }, new[] { 6.3f, 3.3f }, new[] { 4.9f, 2.4f }, new[] { 6.6f, 2.9f }, new[] { 5.2f, 2.7f },
                  new[] { 5.0f, 2.0f }, new[] { 5.9f, 3.0f }, new[] { 6.0f, 2.2f }, new[] { 6.1f, 2.9f }, new[] { 5.6f, 2.9f },
                  new[] { 6.7f, 3.1f }, new[] { 5.6f, 3.0f }, new[] { 5.8f, 2.7f }, new[] { 6.2f, 2.2f }, new[] { 5.6f, 2.5f },
                  new[] { 5.9f, 3.0f }, new[] { 6.1f, 2.8f }, new[] { 6.3f, 2.5f }, new[] { 6.1f, 2.8f }, new[] { 6.4f, 2.9f },
                  new[] { 6.6f, 3.0f }, new[] { 6.8f, 2.8f }, new[] { 6.7f, 3.0f }, new[] { 6.0f, 2.9f }, new[] { 5.7f, 2.6f },
                  new[] { 5.5f, 2.4f }, new[] { 5.5f, 2.4f }, new[] { 5.8f, 2.7f }, new[] { 6.0f, 2.7f }, new[] { 5.4f, 3.0f },
                  new[] { 6.0f, 3.4f }, new[] { 6.7f, 3.1f }, new[] { 6.3f, 2.3f }, new[] { 5.6f, 3.0f }, new[] { 5.5f, 2.5f },
                  new[] { 5.5f, 2.6f }, new[] { 6.1f, 3.0f }, new[] { 5.8f, 2.6f }, new[] { 5.0f, 2.3f }, new[] { 5.6f, 2.7f },
                  new[] { 5.7f, 3.0f }, new[] { 5.7f, 2.9f }, new[] { 6.2f, 2.9f }, new[] { 5.1f, 2.5f }, new[] { 5.7f, 2.8f },
                  // virginica (class 2)
                  new[] { 6.3f, 3.3f }, new[] { 5.8f, 2.7f }, new[] { 7.1f, 3.0f }, new[] { 6.3f, 2.9f }, new[] { 6.5f, 3.0f },
                  new[] { 7.6f, 3.0f }, new[] { 4.9f, 2.5f }, new[] { 7.3f, 2.9f }, new[] { 6.7f, 2.5f }, new[] { 7.2f, 3.6f },
                  new[] { 6.5f, 3.2f }, new[] { 6.4f, 2.7f }, new[] { 6.8f, 3.0f }, new[] { 5.7f, 2.5f }, new[] { 5.8f, 2.8f },
                  new[] { 6.4f, 3.2f }, new[] { 6.5f, 3.0f }, new[] { 7.7f, 3.8f }, new[] { 7.7f, 2.6f }, new[] { 6.0f, 2.2f },
                  new[] { 6.9f, 3.2f }, new[] { 5.6f, 2.8f }, new[] { 7.7f, 2.8f }, new[] { 6.3f, 2.7f }, new[] { 6.7f, 3.3f },
                  new[] { 7.2f, 3.2f }, new[] { 6.2f, 2.8f }, new[] { 6.1f, 3.0f }, new[] { 6.4f, 2.8f }, new[] { 7.2f, 3.0f },
                  new[] { 7.4f, 2.8f }, new[] { 7.9f, 3.8f }, new[] { 6.4f, 2.8f }, new[] { 6.3f, 2.8f }, new[] { 6.1f, 2.6f },
                  new[] { 7.7f, 3.0f }, new[] { 6.3f, 3.4f }, new[] { 6.4f, 3.1f }, new[] { 6.0f, 3.0f }, new[] { 6.9f, 3.1f },
                  new[] { 6.7f, 3.1f }, new[] { 6.9f, 3.1f }, new[] { 5.8f, 2.7f }, new[] { 6.8f, 3.2f }, new[] { 6.7f, 3.3f },
                  new[] { 6.7f, 3.0f }, new[] { 6.3f, 2.5f }, new[] { 6.5f, 3.0f }, new[] { 6.2f, 3.4f }, new[] { 5.9f, 3.0f },
              });
          }

          /// <summary>
          /// Class labels matching <see cref="LoadIrisFeatures"/>: 50 zeros, then 50 ones, then 50 twos.
          /// </summary>
          private static NDArray LoadIrisLabels()
          {
              const int samples_per_class = 50;
              int[] labels = Enumerable.Range(0, 3)
                  .SelectMany(c => Enumerable.Repeat(c, samples_per_class))
                  .ToArray();
              return np.array<int>(labels);
          }
      }
  }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。