You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LinearRegression.cs 5.8 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. /*****************************************************************************
  2. Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ******************************************************************************/
  13. using NumSharp;
  14. using System;
  15. using Tensorflow;
  16. using static Tensorflow.Python;
  17. namespace TensorFlowNET.Examples
  18. {
  /// <summary>
  /// Linear regression example built on the TensorFlow library: fits the model
  /// <c>pred = X * W + b</c> to a small in-memory data set with gradient descent,
  /// then compares the training cost with the cost on a separate test set.
  /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/linear_regression.py
  /// </summary>
  public class LinearRegression : IExample
  {
      public bool Enabled { get; set; } = true;
      public string Name => "Linear Regression";
      public bool IsImportingGraph { get; set; } = false;

      /// <summary>Number of full passes over the training data performed by <see cref="Run"/>.</summary>
      public int training_epochs = 1000;

      // Hyper-parameters
      float learning_rate = 0.01f;   // step size handed to the gradient descent optimizer
      int display_step = 50;         // a progress line is printed every display_step epochs

      NumPyRandom rng = np.random;

      // Training samples and their count; populated by PrepareData().
      NDArray train_X, train_Y;
      int n_samples;

      /// <summary>
      /// Builds the graph, trains the model on the training data and evaluates it on a
      /// fixed test set, writing progress to the console.
      /// </summary>
      /// <returns>
      /// <c>true</c> when the absolute difference between the final training cost and
      /// the testing cost is below 0.01; otherwise <c>false</c>.
      /// </returns>
      public bool Run()
      {
          // Load the training samples first: n_samples is baked into the cost expression below.
          PrepareData();

          // Graph inputs
          var inputX = tf.placeholder(tf.float32);
          var targetY = tf.placeholder(tf.float32);

          // Model parameters. Fixed initial values (rather than random ones) keep runs reproducible
          // and make debugging easier.
          var weight = tf.Variable(-0.06f, name: "weight");
          var bias = tf.Variable(-0.73f, name: "bias");

          // Linear model: prediction = inputX * weight + bias
          var scaled = tf.multiply(inputX, weight);
          var prediction = tf.add(scaled, bias);

          // Cost: sum of squared errors divided by twice the number of training samples
          var squaredError = tf.pow(prediction - targetY, 2.0f);
          var cost = tf.reduce_sum(squaredError) / (2.0f * n_samples);

          // Gradient descent step.
          // Note, minimize() knows to modify weight and bias because Variable objects are trainable=True by default
          var trainStep = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);

          // Op that assigns every variable its initial value
          var initializer = tf.global_variables_initializer();

          return with(tf.Session(), sess =>
          {
              sess.run(initializer);

              // Training: one optimizer step per individual sample, repeated for every epoch
              for (int epoch = 0; epoch < training_epochs; epoch++)
              {
                  foreach (var (sampleX, sampleY) in zip<float>(train_X, train_Y))
                  {
                      sess.run(trainStep,
                          new FeedItem(inputX, sampleX),
                          new FeedItem(targetY, sampleY));
                  }

                  // Only report on every display_step-th epoch (epochs are reported 1-based)
                  int epochNumber = epoch + 1;
                  if (epochNumber % display_step != 0)
                  {
                      continue;
                  }

                  var epochCost = sess.run(cost,
                      new FeedItem(inputX, train_X),
                      new FeedItem(targetY, train_Y));
                  Console.WriteLine($"Epoch: {epochNumber} cost={epochCost} W={sess.run(weight)} b={sess.run(bias)}");
              }

              Console.WriteLine("Optimization Finished!");

              var trainingCost = sess.run(cost,
                  new FeedItem(inputX, train_X),
                  new FeedItem(targetY, train_Y));
              Console.WriteLine($"Training cost={trainingCost} W={sess.run(weight)} b={sess.run(bias)}");

              // Evaluation on a separate set of samples
              var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f);
              var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f);

              Console.WriteLine("Testing... (Mean square loss Comparison)");

              // Same cost formula as above, but normalised by the number of test samples
              var testCost = tf.reduce_sum(tf.pow(prediction - targetY, 2.0f)) / (2.0f * test_X.shape[0]);
              var testingCost = sess.run(testCost,
                  new FeedItem(inputX, test_X),
                  new FeedItem(targetY, test_Y));
              Console.WriteLine($"Testing cost={testingCost}");

              float trainingValue = (float)trainingCost;
              float testingValue = (float)testingCost;
              var gap = Math.Abs(trainingValue - testingValue);
              Console.WriteLine($"Absolute mean square loss difference: {gap}");

              // The example counts as successful when both costs are close to each other
              return gap < 0.01;
          });
      }

      /// <summary>
      /// Fills <c>train_X</c> and <c>train_Y</c> with the hard-coded training samples and
      /// stores the sample count in <c>n_samples</c>.
      /// </summary>
      public void PrepareData()
      {
          train_X = np.array(
              3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f,
              7.042f, 10.791f, 5.313f, 7.997f, 5.654f, 9.27f, 3.1f);

          train_Y = np.array(
              1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f,
              2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f);

          n_samples = train_X.shape[0];
      }

      /// <summary>Not implemented in this example; always throws.</summary>
      /// <exception cref="NotImplementedException">Always.</exception>
      public Graph ImportGraph() => throw new NotImplementedException();

      /// <summary>Not implemented in this example; always throws.</summary>
      /// <exception cref="NotImplementedException">Always.</exception>
      public Graph BuildGraph() => throw new NotImplementedException();

      /// <summary>Not implemented in this example; always throws.</summary>
      /// <exception cref="NotImplementedException">Always.</exception>
      public void Train(Session sess) => throw new NotImplementedException();

      /// <summary>Not implemented in this example; always throws.</summary>
      /// <exception cref="NotImplementedException">Always.</exception>
      public void Predict(Session sess) => throw new NotImplementedException();

      /// <summary>Not implemented in this example; always throws.</summary>
      /// <exception cref="NotImplementedException">Always.</exception>
      public void Test(Session sess) => throw new NotImplementedException();
  }
  127. }