You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

NaiveBayesClassifier.cs 4.3 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using Tensorflow;
  5. using NumSharp.Core;
  6. using System.Linq;
  7. namespace TensorFlowNET.Examples
  8. {
  /// <summary>
  /// Gaussian naive Bayes classifier built on TensorFlow.NET, ported from
  /// https://github.com/nicolov/naive_bayes_tensorflow
  /// </summary>
  /// <remarks>
  /// <see cref="fit"/> estimates a per-class, per-feature mean and variance and stores them as a
  /// normal distribution in <see cref="dist"/>; <see cref="predict"/> turns that model into class
  /// probabilities using uniform class priors.
  /// </remarks>
  public class NaiveBayesClassifier : Python, IExample
  {
      public int Priority => 100;
      public bool Enabled => false;
      public string Name => "Naive Bayes Classifier";

      /// <summary>
      /// Fitted model: normal distribution(s) whose mean and scale have shape
      /// (nb_classes, nb_features). <c>null</c> until <see cref="fit"/> has run.
      /// </summary>
      public Normal dist { get; set; }

      /// <summary>
      /// Fits the classifier on a small hard-coded 2-D, 3-class data set.
      /// </summary>
      /// <returns>
      /// Always <c>false</c>: classifying a regular grid (the remaining step of the
      /// original example) is not implemented yet.
      /// </returns>
      public bool Run()
      {
          // Six 2-D points, two per class.
          var X = np.array(new float[][]
          {
              new float[] { 1.0f, 1.0f },
              new float[] { 2.0f, 2.0f },
              new float[] { -1.0f, -1.0f },
              new float[] { -2.0f, -2.0f },
              new float[] { 1.0f, -1.0f },
              new float[] { 2.0f, -2.0f },
          });
          var y = np.array(0, 0, 1, 1, 2, 2);

          fit(X, y);

          // TODO: create a regular grid and classify each point with predict().
          return false;
      }

      /// <summary>
      /// Estimates the mean and variance of every feature for every class and stores the
      /// resulting normal distribution in <see cref="dist"/>.
      /// </summary>
      /// <param name="X">Training points, shape (nb_samples, nb_features).</param>
      /// <param name="y">Integer class label of each training point, length nb_samples.</param>
      /// <exception cref="ArgumentException">
      /// <paramref name="y"/> is empty, the row count of <paramref name="X"/> differs from the
      /// length of <paramref name="y"/>, or the classes do not all have the same number of samples.
      /// </exception>
      public void fit(NDArray X, NDArray y)
      {
          int nbSamples = y.size;
          if (nbSamples == 0)
          {
              throw new ArgumentException("At least one training sample is required.", nameof(y));
          }
          if (X.shape[0] != nbSamples)
          {
              throw new ArgumentException("X must have exactly one row per label in y.", nameof(X));
          }
          int nbFeatures = X.shape[1];

          // Group the training points by label. SortedDictionary keeps labels in ascending order
          // (like numpy's unique), so row c of the model is the c-th smallest label even when the
          // labels are not exactly 0..nb_classes-1.
          var pointsByLabel = new SortedDictionary<long, List<float[]>>();
          for (int i = 0; i < nbSamples; i++)
          {
              long label = (long)y[i];
              var point = new float[nbFeatures];
              for (int k = 0; k < nbFeatures; k++)
              {
                  point[k] = (float)X[i, k];
              }

              List<float[]> bucket;
              if (!pointsByLabel.TryGetValue(label, out bucket))
              {
                  bucket = new List<float[]>();
                  pointsByLabel.Add(label, bucket);
              }
              bucket.Add(point);
          }

          // tf.nn.moments below needs a dense (nb_classes, samples_per_class, nb_features) array,
          // so every class must contribute the same number of samples; padding smaller classes
          // with zeros would silently bias their mean and variance.
          int perClass = pointsByLabel.Values.First().Count;
          if (pointsByLabel.Values.Any(points => points.Count != perClass))
          {
              throw new ArgumentException("Every class must have the same number of training samples.", nameof(y));
          }

          float[,,] grouped = new float[pointsByLabel.Count, perClass, nbFeatures];
          int c = 0;
          foreach (List<float[]> classPoints in pointsByLabel.Values)
          {
              for (int i = 0; i < perClass; i++)
              {
                  for (int k = 0; k < nbFeatures; k++)
                  {
                      grouped[c, i, k] = classPoints[i][k];
                  }
              }
              c++;
          }
          NDArray points_by_class = np.array<float>(grouped);

          // Moments over the sample axis (1) give mean and variance of shape (nb_classes, nb_features).
          var moments = tf.nn.moments(tf.constant(points_by_class), new int[] { 1 });
          var mean = moments.Item1;
          var variance = moments.Item2;

          // One univariate normal per (class, feature) with the estimated mean and
          // standard deviation sqrt(variance).
          this.dist = tf.distributions.Normal(mean, tf.sqrt(variance));
      }

      /// <summary>
      /// Computes the posterior probability of every class for every input point.
      /// </summary>
      /// <param name="X">
      /// Points to classify, shape (nb_samples, nb_features).
      /// NOTE(review): presumably must have the same dtype (float32) as the training data — confirm.
      /// </param>
      /// <returns>Tensor of shape (nb_samples, nb_classes) whose rows are normalized to sum to 1.</returns>
      /// <exception cref="ArgumentNullException"><see cref="fit"/> has not been called yet.</exception>
      public Tensor predict(NDArray X)
      {
          if (dist == null)
          {
              // Exception type kept for existing callers; the text now goes in the message slot
              // instead of being passed as the parameter name.
              throw new ArgumentNullException(nameof(dist), "Cannot find the model (normal distribution); call fit() first.");
          }

          int nb_classes = (int)dist.scale().shape[0];
          int nb_features = (int)dist.scale().shape[1];

          // Repeat each sample once per class along the feature axis, then view the result as
          // (nb_samples, nb_classes, nb_features) so it lines up with the per-class distribution.
          Tensor tiled = tf.tile(new Tensor(X), new Tensor(new int[] { 1, nb_classes }));
          Tensor r = tf.reshape(tiled, new Tensor(new int[] { -1, nb_classes, nb_features }));

          // Conditional log-likelihood log P(x|c): under the naive independence assumption the
          // per-feature log densities add, so sum over the feature axis -> (nb_samples, nb_classes).
          var cond_probs = tf.reduce_sum(dist.log_prob(r), new int[] { 2 });

          // Uniform priors: log P(c) = log(1 / nb_classes) for each class, as float32 so the
          // addition below does not mix float64 with the float32 likelihoods.
          float logPrior = (float)Math.Log(1.0 / nb_classes);
          var priors = np.array(Enumerable.Repeat(logPrior, nb_classes).ToArray());

          // Unnormalized posterior log probability: log P(c) + log P(x|c).
          var joint_likelihood = tf.add(new Tensor(priors), cond_probs);

          // Normalize with log-sum-exp over the class axis to get log-probabilities.
          var norm_factor = tf.reduce_logsumexp(joint_likelihood, new int[] { 1 }, true);
          var log_prob = joint_likelihood - norm_factor;

          // exp to get the actual probabilities.
          return tf.exp(log_prob);
      }

      /// <summary>No-op: the training data is built inline in <see cref="Run"/>.</summary>
      public void PrepareData()
      {
      }
  }
  104. }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。