You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LinearRegression.cs 4.3 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. using Newtonsoft.Json;
  2. using NumSharp.Core;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Text;
  6. using Tensorflow;
  namespace TensorFlowNET.Examples
  {
      /// <summary>
      /// A linear regression learning algorithm example using TensorFlow library.
      /// Fits <c>pred = X * W + b</c> to a small fixed one-dimensional dataset by running
      /// per-sample gradient descent on the (halved) mean squared error.
      /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/linear_regression.py
      /// </summary>
      public class LinearRegression : Python, IExample
      {
          /// <summary>
          /// Builds the regression graph, trains it for <c>training_epochs</c> passes over the
          /// training data, and writes periodic progress plus the learned weight and bias
          /// to the console.
          /// </summary>
          public void Run()
          {
              var graph = tf.Graph().as_default();

              // Parameters
              float learning_rate = 0.01f;
              int training_epochs = 1000;
              int display_step = 10;

              // Training data
              var train_X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f,
                  7.042f, 10.791f, 5.313f, 7.997f, 5.654f, 9.27f, 3.1f);
              var train_Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f,
                  2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f);
              var n_samples = train_X.shape[0];

              // tf Graph input
              var X = tf.placeholder(tf.float32);
              var Y = tf.placeholder(tf.float32);

              // Model weights. Fixed starting values (the upstream Python example draws them
              // with rng.randn()) keep every run reproducible.
              var W = tf.Variable(-0.06f, name: "weight");
              var b = tf.Variable(-0.73f, name: "bias");

              // Linear model: pred = X * W + b
              var pred = tf.add(tf.multiply(X, W), b);

              // Mean squared error: sum((pred - Y)^2) / (2 * n_samples)
              var cost = tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * n_samples);

              // Gradient descent.
              // Note, minimize() knows to modify W and b because Variable objects are trainable=True by default.
              var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);

              // Initialize the variables (i.e. assign their default value)
              var init = tf.global_variables_initializer();

              // Start training
              with<Session>(tf.Session(graph), sess =>
              {
                  sess.run(init);

                  for (int epoch = 0; epoch < training_epochs; epoch++)
                  {
                      // One optimizer step per training sample.
                      foreach (var (x, y) in zip<float>(train_X, train_Y))
                      {
                          sess.run(optimizer,
                              new FeedItem(X, x),
                              new FeedItem(Y, y));
                      }

                      // Display logs per epoch step.
                      // NOTE(review): the upstream example also prints the full-dataset cost here;
                      // that needs train_X/train_Y fed as whole arrays — confirm FeedItem supports
                      // NDArray values before adding it.
                      if ((epoch + 1) % display_step == 0)
                      {
                          Console.WriteLine($"Epoch: {epoch + 1} W={sess.run(W)} b={sess.run(b)}");
                      }
                  }

                  Console.WriteLine("Optimization Finished!");
                  Console.WriteLine($"W={sess.run(W)} b={sess.run(b)}");
              });
          }
      }
  }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。