You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

NaiveBayesClassifier.cs 11 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using Tensorflow;
  5. using NumSharp.Core;
  6. using System.Linq;
  7. namespace TensorFlowNET.Examples
  8. {
  9. /// <summary>
  10. /// https://github.com/nicolov/naive_bayes_tensorflow
  11. /// </summary>
  12. public class NaiveBayesClassifier : Python, IExample
  13. {
  /// <summary>Fixed priority value reported by this example.</summary>
  public int Priority
  {
      get { return 100; }
  }

  /// <summary>This example always reports itself as enabled.</summary>
  public bool Enabled
  {
      get { return true; }
  }

  /// <summary>Display name of this example.</summary>
  public string Name
  {
      get { return "Naive Bayes Classifier"; }
  }

  /// <summary>
  /// Normal distribution assigned by <c>fit</c> and read by <c>predict</c>;
  /// <c>predict</c> throws while this is still null.
  /// </summary>
  public Normal dist { get; set; }
  /// <summary>
  /// Fits the classifier on the embedded data set (150 samples, 2 features, 3 classes)
  /// and evaluates the class probabilities on a regular 30 x 30 grid spanning the data.
  /// </summary>
  /// <returns>Always <c>true</c> once the prediction graph has been evaluated.</returns>
  public bool Run()
  {
      // Training samples, one {feature0, feature1} pair per row; row i is labelled by y[i].
      var X = np.array<double>(new double[][] {
          new double[] { 5.1, 3.5},new double[] { 4.9, 3.0 },new double[] { 4.7, 3.2 },
          new double[] { 4.6, 3.1 },new double[] { 5.0, 3.6 },new double[] { 5.4, 3.9 },
          new double[] { 4.6, 3.4 },new double[] { 5.0, 3.4 },new double[] { 4.4, 2.9 },
          new double[] { 4.9, 3.1 },new double[] { 5.4, 3.7 },new double[] {4.8, 3.4 },
          new double[] {4.8, 3.0 },new double[] {4.3, 3.0 },new double[] {5.8, 4.0 },
          new double[] {5.7, 4.4 },new double[] {5.4, 3.9 },new double[] {5.1, 3.5 },
          new double[] {5.7, 3.8 },new double[] {5.1, 3.8 },new double[] {5.4, 3.4 },
          new double[] {5.1, 3.7 },new double[] {5.1, 3.3 },new double[] {4.8, 3.4 },
          new double[] {5.0 , 3.0 },new double[] {5.0 , 3.4 },new double[] {5.2, 3.5 },
          new double[] {5.2, 3.4 },new double[] {4.7, 3.2 },new double[] {4.8, 3.1 },
          new double[] {5.4, 3.4 },new double[] {5.2, 4.1},new double[] {5.5, 4.2 },
          new double[] {4.9, 3.1 },new double[] {5.0 , 3.2 },new double[] {5.5, 3.5 },
          new double[] {4.9, 3.6 },new double[] {4.4, 3.0 },new double[] {5.1, 3.4 },
          new double[] {5.0 , 3.5 },new double[] {4.5, 2.3 },new double[] {4.4, 3.2 },
          new double[] {5.0 , 3.5 },new double[] {5.1, 3.8 },new double[] {4.8, 3.0},
          new double[] {5.1, 3.8 },new double[] {4.6, 3.2 },new double[] { 5.3, 3.7 },
          new double[] {5.0 , 3.3 },new double[] {7.0 , 3.2 },new double[] {6.4, 3.2 },
          new double[] {6.9, 3.1 },new double[] {5.5, 2.3 },new double[] {6.5, 2.8 },
          new double[] {5.7, 2.8 },new double[] {6.3, 3.3 },new double[] {4.9, 2.4 },
          new double[] {6.6, 2.9 },new double[] {5.2, 2.7 },new double[] {5.0 , 2.0 },
          new double[] {5.9, 3.0 },new double[] {6.0 , 2.2 },new double[] {6.1, 2.9 },
          new double[] {5.6, 2.9 },new double[] {6.7, 3.1 },new double[] {5.6, 3.0 },
          new double[] {5.8, 2.7 },new double[] {6.2, 2.2 },new double[] {5.6, 2.5 },
          new double[] {5.9, 3.0},new double[] {6.1, 2.8},new double[] {6.3, 2.5},
          new double[] {6.1, 2.8},new double[] {6.4, 2.9},new double[] {6.6, 3.0 },
          new double[] {6.8, 2.8},new double[] {6.7, 3.0 },new double[] {6.0 , 2.9},
          new double[] {5.7, 2.6},new double[] {5.5, 2.4},new double[] {5.5, 2.4},
          new double[] {5.8, 2.7},new double[] {6.0 , 2.7},new double[] {5.4, 3.0 },
          new double[] {6.0 , 3.4},new double[] {6.7, 3.1},new double[] {6.3, 2.3},
          new double[] {5.6, 3.0 },new double[] {5.5, 2.5},new double[] {5.5, 2.6},
          new double[] {6.1, 3.0 },new double[] {5.8, 2.6},new double[] {5.0 , 2.3},
          new double[] {5.6, 2.7},new double[] {5.7, 3.0 },new double[] {5.7, 2.9},
          new double[] {6.2, 2.9},new double[] {5.1, 2.5},new double[] {5.7, 2.8},
          new double[] {6.3, 3.3},new double[] {5.8, 2.7},new double[] {7.1, 3.0 },
          new double[] {6.3, 2.9},new double[] {6.5, 3.0 },new double[] {7.6, 3.0 },
          new double[] {4.9, 2.5},new double[] {7.3, 2.9},new double[] {6.7, 2.5},
          new double[] {7.2, 3.6},new double[] {6.5, 3.2},new double[] {6.4, 2.7},
          new double[] {6.8, 3.00 },new double[] {5.7, 2.5},new double[] {5.8, 2.8},
          new double[] {6.4, 3.2},new double[] {6.5, 3.0 },new double[] {7.7, 3.8},
          new double[] {7.7, 2.6},new double[] {6.0 , 2.2},new double[] {6.9, 3.2},
          new double[] {5.6, 2.8},new double[] {7.7, 2.8},new double[] {6.3, 2.7},
          new double[] {6.7, 3.3},new double[] {7.2, 3.2},new double[] {6.2, 2.8},
          new double[] {6.1, 3.0 },new double[] {6.4, 2.8},new double[] {7.2, 3.0 },
          new double[] {7.4, 2.8},new double[] {7.9, 3.8},new double[] {6.4, 2.8},
          new double[] {6.3, 2.8},new double[] {6.1, 2.6},new double[] {7.7, 3.0 },
          new double[] {6.3, 3.4},new double[] {6.4, 3.1},new double[] {6.0, 3.0},
          new double[] {6.9, 3.1},new double[] {6.7, 3.1},new double[] {6.9, 3.1},
          new double[] {5.8, 2.7},new double[] {6.8, 3.2},new double[] {6.7, 3.3},
          new double[] {6.7, 3.0 },new double[] {6.3, 2.5},new double[] {6.5, 3.0 },
          new double[] {6.2, 3.4},new double[] {5.9, 3.0 }, new double[] {5.8, 3.0 }});

      // Class labels: 50 samples each for classes 0, 1 and 2, in the same order as the rows of X.
      var y = np.array<int>(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
          2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
          2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2);

      fit(X, y);

      // Create a regular grid (padded by 0.5 on every side of the data range)
      // and classify each point
      double x_min = (double) X.amin(0)[0] - 0.5;
      double y_min = (double) X.amin(0)[1] - 0.5;
      double x_max = (double) X.amax(0)[0] + 0.5;
      double y_max = (double) X.amax(0)[1] + 0.5;
      var (xx, yy) = np.meshgrid(np.linspace(x_min, x_max, 30), np.linspace(y_min, y_max, 30));

      // vstack stacks the two flattened grids as rows, i.e. (nb_features, nb_samples).
      // predict tiles and reshapes its input as one sample per row, so transpose first,
      // matching the reference implementation's np.vstack([...]).T.
      var samples = np.transpose(np.vstack(xx.ravel(), yy.ravel()));

      // Dispose the session once the graph has been evaluated so its native
      // resources are released instead of leaking.
      using (var s = tf.Session())
      {
          // Z holds the evaluated class probabilities; it is not used further here.
          var Z = s.run(predict(samples));
      }
      return true;
  }
  /// <summary>
  /// Groups the training samples by class label, estimates the per-class, per-feature
  /// mean and variance, and stores the resulting normal distribution in <c>dist</c>.
  /// </summary>
  /// <param name="X">Training samples indexed as X[sample, feature].</param>
  /// <param name="y">Class label of each sample; y[i] labels row i of X.</param>
  /// <exception cref="ArgumentOutOfRangeException">
  /// Thrown when the classes do not all have the same number of samples, because the
  /// samples are packed into a dense (nb_classes, nb_samples, nb_features) array.
  /// </exception>
  public void fit(NDArray X, NDArray y)
  {
      int nb_features = X.shape[1];

      // Separate training points by class.
      // Keys are kept sorted so every class gets a stable, contiguous index
      // (its rank among the labels) even when labels are not exactly 0..n-1.
      var samplesByClass = new SortedDictionary<long, List<double[]>>();
      int maxCount = 0;
      for (int i = 0; i < y.size; i++)
      {
          long label = (long)y[i];
          if (!samplesByClass.TryGetValue(label, out List<double[]> bucket))
          {
              bucket = new List<double[]>();
              samplesByClass.Add(label, bucket);
          }

          // Copy every feature column rather than assuming exactly two.
          var features = new double[nb_features];
          for (int k = 0; k < nb_features; k++)
          {
              features[k] = (double)X[i, k];
          }
          bucket.Add(features);

          if (bucket.Count > maxCount)
          {
              maxCount = bucket.Count;
          }
      }

      // Shape : nb_classes * nb_samples * nb_features
      double[,,] points = new double[samplesByClass.Count, maxCount, nb_features];
      int classIndex = 0;
      foreach (KeyValuePair<long, List<double[]>> kv in samplesByClass)
      {
          List<double[]> classRows = kv.Value;
          if (classRows.Count != maxCount)
          {
              // Previously this surfaced as an opaque list-indexer failure; padding
              // with zeros instead would silently bias the moments, so fail loudly.
              throw new ArgumentOutOfRangeException(nameof(y),
                  "Every class must have the same number of samples; class " + kv.Key +
                  " has " + classRows.Count + " but " + maxCount + " were expected.");
          }

          for (int i = 0; i < maxCount; i++)
          {
              for (int k = 0; k < nb_features; k++)
              {
                  points[classIndex, i, k] = classRows[i][k];
              }
          }
          classIndex++;
      }
      NDArray points_by_class = np.array<double>(points);

      // estimate mean and variance for each class / feature
      // shape : nb_classes * nb_features
      var cons = tf.constant(points_by_class);
      var tup = tf.nn.moments(cons, new int[]{1});
      var mean = tup.Item1;
      var variance = tup.Item2;

      // Create a (nb_classes x nb_features) univariate normal distribution with the
      // known mean and variance
      this.dist = tf.distributions.Normal(mean, tf.sqrt(variance));
  }
  /// <summary>
  /// Builds the tensor of class probabilities P(c|x) for the given samples using the
  /// distribution stored in <c>dist</c>. The tensor is only constructed here;
  /// the caller evaluates it in a session.
  /// </summary>
  /// <param name="X">Samples to classify, one sample per row (nb_samples, nb_features).</param>
  /// <returns>Tensor of normalized probabilities with shape (nb_samples, nb_classes).</returns>
  /// <exception cref="ArgumentNullException">Thrown when <c>dist</c> has not been set (fit not called).</exception>
  public Tensor predict (NDArray X)
  {
      if (dist == null)
      {
          // Exception type kept for existing callers; the two-argument overload is used
          // so the text is reported as the message rather than as a parameter name.
          throw new ArgumentNullException(nameof(dist), "Cannot find the model (normal distribution); call fit() first.");
      }
      int nb_classes = (int) dist.scale().shape[0];
      int nb_features = (int)dist.scale().shape[1];

      // Conditional probabilities log P(x|c) with shape
      // (nb_samples, nb_classes)
      // Each sample is repeated once per class, then viewed as
      // (nb_samples, nb_classes, nb_features) so it lines up with the distribution.
      Tensor tile = tf.tile(X, new int[] { 1, nb_classes });
      Tensor r = tf.reshape(tile, new Tensor(new int[] { -1, nb_classes, nb_features }));
      // Sum the per-feature log-probabilities over the feature axis only; reducing over
      // every axis would collapse the result to a single scalar.
      var cond_probs = tf.reduce_sum(dist.log_prob(r), 2);

      // uniform priors: one entry of 1 / nb_classes per class
      double[] uniform = new double[nb_classes];
      for (int c = 0; c < nb_classes; c++)
      {
          uniform[c] = 1.0 / nb_classes;
      }
      var priors = np.log(np.array<double>(uniform));

      // posterior log probability, log P(c) + log P(x|c)
      var joint_likelihood = tf.add(new Tensor(priors), cond_probs);

      // normalize to get (log)-probabilities
      var norm_factor = tf.reduce_logsumexp(joint_likelihood, new int[] { 1 }, true);
      var log_prob = joint_likelihood - norm_factor;

      // exp to get the actual probabilities
      return tf.exp(log_prob);
  }
  /// <summary>
  /// Does nothing: the training data for this example is declared inline in <c>Run</c>.
  /// </summary>
  public void PrepareData() { }
  168. }
  169. }

tensorflow框架的.NET版本,提供了丰富的特性和API,可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。