You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LinearRegression.cs 4.5 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using Tensorflow;
  6. using static Tensorflow.Python;
  7. namespace TensorFlowNET.Examples
  8. {
  9. /// <summary>
  10. /// A linear regression learning algorithm example using TensorFlow library.
  11. /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/linear_regression.py
  12. /// </summary>
  13. public class LinearRegression : IExample
  14. {
  15. public int Priority => 3;
  16. public bool Enabled { get; set; } = true;
  17. public string Name => "Linear Regression";
  18. public bool ImportGraph { get; set; } = false;
  19. public int training_epochs = 1000;
  20. // Parameters
  21. float learning_rate = 0.01f;
  22. int display_step = 50;
  23. NumPyRandom rng = np.random;
  24. NDArray train_X, train_Y;
  25. int n_samples;
  26. public bool Run()
  27. {
  28. // Training Data
  29. PrepareData();
  30. // tf Graph Input
  31. var X = tf.placeholder(tf.float32);
  32. var Y = tf.placeholder(tf.float32);
  33. // Set model weights
  34. // We can set a fixed init value in order to debug
  35. // var rnd1 = rng.randn<float>();
  36. // var rnd2 = rng.randn<float>();
  37. var W = tf.Variable(-0.06f, name: "weight");
  38. var b = tf.Variable(-0.73f, name: "bias");
  39. // Construct a linear model
  40. var pred = tf.add(tf.multiply(X, W), b);
  41. // Mean squared error
  42. var cost = tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * n_samples);
  43. // Gradient descent
  44. // Note, minimize() knows to modify W and b because Variable objects are trainable=True by default
  45. var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);
  46. // Initialize the variables (i.e. assign their default value)
  47. var init = tf.global_variables_initializer();
  48. // Start training
  49. return with(tf.Session(), sess =>
  50. {
  51. // Run the initializer
  52. sess.run(init);
  53. // Fit all training data
  54. for (int epoch = 0; epoch < training_epochs; epoch++)
  55. {
  56. foreach (var (x, y) in zip<float>(train_X, train_Y))
  57. {
  58. sess.run(optimizer,
  59. new FeedItem(X, x),
  60. new FeedItem(Y, y));
  61. }
  62. // Display logs per epoch step
  63. if ((epoch + 1) % display_step == 0)
  64. {
  65. var c = sess.run(cost,
  66. new FeedItem(X, train_X),
  67. new FeedItem(Y, train_Y));
  68. Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}");
  69. }
  70. }
  71. Console.WriteLine("Optimization Finished!");
  72. var training_cost = sess.run(cost,
  73. new FeedItem(X, train_X),
  74. new FeedItem(Y, train_Y));
  75. Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}");
  76. // Testing example
  77. var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f);
  78. var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f);
  79. Console.WriteLine("Testing... (Mean square loss Comparison)");
  80. var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]),
  81. new FeedItem(X, test_X),
  82. new FeedItem(Y, test_Y));
  83. Console.WriteLine($"Testing cost={testing_cost}");
  84. var diff = Math.Abs((float)training_cost - (float)testing_cost);
  85. Console.WriteLine($"Absolute mean square loss difference: {diff}");
  86. return diff < 0.01;
  87. });
  88. }
  89. public void PrepareData()
  90. {
  91. train_X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f,
  92. 7.042f, 10.791f, 5.313f, 7.997f, 5.654f, 9.27f, 3.1f);
  93. train_Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f,
  94. 2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f);
  95. n_samples = train_X.shape[0];
  96. }
  97. }
  98. }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。