You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Keras.cs 2.7 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. using System;
  2. using System.Collections.Generic;
  3. using Tensorflow;
  4. using Keras.Layers;
  5. using NumSharp;
  6. using Keras;
  7. using static Tensorflow.Binding;
  namespace TensorFlowNET.Examples
  {
      /// <summary>
      /// Example that trains a small two-layer Keras-style <c>Model</c> on a randomly
      /// generated dataset of two binary inputs and one binary label per row.
      /// </summary>
      /// <remarks>
      /// Only <see cref="Run"/> is implemented; every other <c>IExample</c> member
      /// throws <see cref="NotImplementedException"/>.
      /// </remarks>
      public class Keras : IExample
      {
          /// <summary>Number of rows generated for the training set.</summary>
          private const int SampleCount = 1000;

          /// <summary>Number of training steps passed to <c>Model.train</c>.</summary>
          private const int TrainingSteps = 10000;

          public bool Enabled { get; set; } = true;
          public bool IsImportingGraph { get; set; } = false;
          public string Name => "Keras";

          /// <summary>
          /// Generates the dataset, builds the model and trains it.
          /// </summary>
          /// <returns>Always <c>true</c> once training has been invoked.</returns>
          public bool Run()
          {
              Console.WriteLine("================================== Keras ==================================");

              // Data: random binary input pairs; see XOR for the labelling rule.
              var (inputs, targets) = XOR(SampleCount);

              // Model: an 8-unit ReLU hidden layer followed by a single output unit
              // (no activation argument is passed to the output layer).
              var model = new Model();
              model.Add(new ILayer[]
              {
                  new Dense(8, name: "Hidden_1", activation: tf.nn.relu()),
                  new Dense(1, name: "Output")
              });

              // The NDArrays are handed to the model directly; no intermediate Tensor
              // objects are created, so nothing native is allocated and left undisposed.
              model.train(TrainingSteps, (inputs, targets));
              return true;
          }

          /// <summary>
          /// Generates <paramref name="samples"/> random rows of two inputs, each exactly 0 or 1.
          /// </summary>
          /// <param name="samples">Number of rows to generate; must be non-negative.</param>
          /// <param name="random">
          /// Optional random source. Pass a seeded instance for reproducible data; when
          /// <c>null</c>, a new default-constructed <see cref="Random"/> is used.
          /// </param>
          /// <returns>
          /// The inputs, built from a <c>float[][]</c> with one two-element row per sample, and
          /// the labels, built from a <c>float[]</c> with one entry per sample.
          /// </returns>
          /// <remarks>
          /// A row is labelled 1 when both inputs are equal and 0 otherwise. That is the complement
          /// of XOR (XNOR); like XOR it is not linearly separable.
          /// NOTE(review): confirm the XNOR labelling is intended, given the method name.
          /// </remarks>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="samples"/> is negative.</exception>
          static (NDArray, NDArray) XOR(int samples, Random random = null)
          {
              if (samples < 0)
                  throw new ArgumentOutOfRangeException(nameof(samples), samples, "Sample count must be non-negative.");

              var rng = random ?? new Random();
              var inputs = new float[samples][];
              var labels = new float[samples];

              for (int i = 0; i < samples; i++)
              {
                  // Draw x1 before x2, the same order as the original implementation.
                  float x1 = rng.Next(0, 2);
                  float x2 = rng.Next(0, 2);
                  inputs[i] = new[] { x1, x2 };
                  // Exact float equality is safe here: both values are exactly 0f or 1f.
                  labels[i] = x1 == x2 ? 1.0f : 0.0f;
              }

              return (np.array(inputs), np.array(labels));
          }

          /// <summary>Not implemented for this example.</summary>
          public Graph BuildGraph()
          {
              throw new NotImplementedException();
          }

          /// <summary>Not implemented for this example.</summary>
          public Graph ImportGraph()
          {
              throw new NotImplementedException();
          }

          /// <summary>Not implemented for this example.</summary>
          public void Predict(Session sess)
          {
              throw new NotImplementedException();
          }

          /// <summary>Not implemented for this example.</summary>
          public void PrepareData()
          {
              throw new NotImplementedException();
          }

          /// <summary>Not implemented for this example.</summary>
          public void Test(Session sess)
          {
              throw new NotImplementedException();
          }

          /// <summary>Not implemented for this example.</summary>
          public void Train(Session sess)
          {
              throw new NotImplementedException();
          }
      }
  }