You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Keras.cs 2.6 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. using System;
  2. using System.Collections.Generic;
  3. using Tensorflow;
  4. using Keras.Layers;
  5. using NumSharp;
  6. using Keras;
  7. namespace TensorFlowNET.Examples
  8. {
  9. public class Keras : IExample
  10. {
  11. public bool Enabled { get; set; } = true;
  12. public bool IsImportingGraph { get; set; } = false;
  13. public string Name => "Keras";
  14. public bool Run()
  15. {
  16. Console.WriteLine("================================== Keras ==================================");
  17. #region data
  18. var batch_size = 1000;
  19. var (X, Y) = XOR(batch_size);
  20. //var (X, Y, batch_size) = (np.array(new float[,]{{1, 0 },{1, 1 },{0, 0 },{0, 1 }}), np.array(new int[] { 0, 1, 1, 0 }), 4);
  21. #endregion
  22. #region features
  23. var (features, labels) = (new Tensor(X), new Tensor(Y));
  24. var num_steps = 10000;
  25. #endregion
  26. #region model
  27. var m = new Model();
  28. //m.Add(new Dense(8, name: "Hidden", activation: tf.nn.relu())).Add(new Dense(1, name:"Output"));
  29. m.Add(
  30. new ILayer[] {
  31. new Dense(8, name: "Hidden_1", activation: tf.nn.relu()),
  32. new Dense(1, name: "Output")
  33. });
  34. m.train(num_steps, (X, Y));
  35. #endregion
  36. return true;
  37. }
  38. static (NDArray, NDArray) XOR(int samples)
  39. {
  40. var X = new List<float[]>();
  41. var Y = new List<float>();
  42. var r = new Random();
  43. for (int i = 0; i < samples; i++)
  44. {
  45. var x1 = (float)r.Next(0, 2);
  46. var x2 = (float)r.Next(0, 2);
  47. var y = 0.0f;
  48. if (x1 == x2)
  49. y = 1.0f;
  50. X.Add(new float[] { x1, x2 });
  51. Y.Add(y);
  52. }
  53. return (np.array(X.ToArray()), np.array(Y.ToArray()));
  54. }
  55. public Graph BuildGraph()
  56. {
  57. throw new NotImplementedException();
  58. }
  59. public Graph ImportGraph()
  60. {
  61. throw new NotImplementedException();
  62. }
  63. public void Predict(Session sess)
  64. {
  65. throw new NotImplementedException();
  66. }
  67. public void PrepareData()
  68. {
  69. throw new NotImplementedException();
  70. }
  71. public void Test(Session sess)
  72. {
  73. throw new NotImplementedException();
  74. }
  75. public void Train(Session sess)
  76. {
  77. throw new NotImplementedException();
  78. }
  79. }
  80. }