You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

NaiveBayesClassifier.cs 2.4 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using Tensorflow;
  5. using NumSharp.Core;
  6. using System.Linq;
  namespace TensorFlowNET.Examples
  {
      /// <summary>
      /// Gaussian naive Bayes classifier, a port of
      /// https://github.com/nicolov/naive_bayes_tensorflow
      /// </summary>
      public class NaiveBayesClassifier : Python, IExample
      {
          // Sorted distinct class labels seen by fit(); row i of _mean/_variance describes _classes[i].
          private int[] _classes;

          // Per-class, per-feature mean; shape nb_classes x nb_features.
          private double[,] _mean;

          // Per-class, per-feature population variance (sum of squared deviations divided by N,
          // the same estimator tf.nn.moments uses); shape nb_classes x nb_features.
          private double[,] _variance;

          public void Run()
          {
              // TODO: load a dataset, then call fit() and predict() to demonstrate the classifier.
          }

          /// <summary>
          /// Estimates the mean and variance of every feature for every class in the training set.
          /// </summary>
          /// <remarks>
          /// Statistics are accumulated per class rather than by stacking all classes into one
          /// zero-padded (nb_classes, max_count, nb_features) tensor: that padding would pull the
          /// mean and variance of smaller classes toward zero whenever class sizes differ.
          /// </remarks>
          /// <param name="X">Training samples, read as a row-major float32 matrix of shape
          /// (nb_samples, nb_features). NOTE(review): assumes X really has dtype float32 — confirm with callers.</param>
          /// <param name="y">One integer class label per sample.
          /// NOTE(review): assumes int32 dtype — confirm with callers.</param>
          /// <exception cref="ArgumentNullException"><paramref name="X"/> or <paramref name="y"/> is null.</exception>
          /// <exception cref="ArgumentException">The training set is empty, the element count of X differs
          /// from shape[0] * shape[1], or the number of labels differs from the number of samples.</exception>
          public void fit(NDArray X, NDArray y)
          {
              if (X is null)
              {
                  throw new ArgumentNullException(nameof(X));
              }
              if (y is null)
              {
                  throw new ArgumentNullException(nameof(y));
              }

              int nbSamples = X.shape[0];
              int nbFeatures = X.shape[1];
              float[] features = X.Data<float>().ToArray();
              int[] labels = y.Data<int>().ToArray();

              if (nbSamples == 0)
              {
                  throw new ArgumentException("Cannot fit on an empty training set.", nameof(X));
              }
              // A mismatch here means X is not laid out as a 2-D (nb_samples, nb_features) matrix.
              if (features.Length != nbSamples * nbFeatures)
              {
                  throw new ArgumentException(
                      $"Expected {nbSamples * nbFeatures} feature values but got {features.Length}.", nameof(X));
              }
              if (labels.Length != nbSamples)
              {
                  throw new ArgumentException(
                      $"Expected {nbSamples} labels but got {labels.Length}.", nameof(y));
              }

              int[] classes = labels.Distinct().OrderBy(label => label).ToArray();

              // Maps a class label to its row in the statistics matrices.
              var rowOf = new Dictionary<int, int>(classes.Length);
              for (int c = 0; c < classes.Length; c++)
              {
                  rowOf[classes[c]] = c;
              }

              var counts = new int[classes.Length];
              var mean = new double[classes.Length, nbFeatures];
              var variance = new double[classes.Length, nbFeatures];

              // Pass 1: per-class feature sums, turned into means below.
              for (int s = 0; s < nbSamples; s++)
              {
                  int c = rowOf[labels[s]];
                  counts[c]++;
                  for (int f = 0; f < nbFeatures; f++)
                  {
                      mean[c, f] += features[s * nbFeatures + f];
                  }
              }
              for (int c = 0; c < classes.Length; c++)
              {
                  // counts[c] >= 1 because every class was taken from labels.
                  for (int f = 0; f < nbFeatures; f++)
                  {
                      mean[c, f] /= counts[c];
                  }
              }

              // Pass 2: squared deviations from the class mean. Two passes avoid the
              // cancellation error of the E[x^2] - E[x]^2 shortcut.
              for (int s = 0; s < nbSamples; s++)
              {
                  int c = rowOf[labels[s]];
                  for (int f = 0; f < nbFeatures; f++)
                  {
                      double deviation = features[s * nbFeatures + f] - mean[c, f];
                      variance[c, f] += deviation * deviation;
                  }
              }
              for (int c = 0; c < classes.Length; c++)
              {
                  for (int f = 0; f < nbFeatures; f++)
                  {
                      variance[c, f] /= counts[c];
                  }
              }

              // Publish only after every check and computation succeeded, so a failed fit
              // leaves any previously fitted state untouched.
              _classes = classes;
              _mean = mean;
              _variance = variance;
          }

          /// <summary>
          /// Predicts class labels for <paramref name="X"/>. Not implemented yet.
          /// </summary>
          /// <param name="X">Samples to classify.</param>
          /// <exception cref="InvalidOperationException"><see cref="fit"/> has not been called.</exception>
          /// <exception cref="NotImplementedException">Always, once the model has been fitted.</exception>
          public void predict(NDArray X)
          {
              // Mirrors `assert self.dist is not None` in the Python original.
              if (_mean == null)
              {
                  throw new InvalidOperationException("fit() must be called before predict().");
              }
              throw new NotImplementedException("NaiveBayesClassifier.predict is not implemented yet.");
          }
      }
  }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。