You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

BasicLinearModel.cs 2.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Text;
  6. using Tensorflow;
  7. using static Tensorflow.Binding;
  8. namespace TensorFlowNET.UnitTest.Training
  9. {
  10. [TestClass]
  11. public class BasicLinearModel
  12. {
  13. /// <summary>
  14. /// Linear Regression without tf.train.Optimizer
  15. /// https://www.tensorflow.org/tutorials/customization/custom_training
  16. /// </summary>
  17. [TestMethod]
  18. public void LinearRegression()
  19. {
  20. // Initialize the weights to `5.0` and the bias to `0.0`
  21. // In practice, these should be initialized to random values (for example, with `tf.random.normal`)
  22. var W = tf.Variable(5.0f);
  23. var b = tf.Variable(0.0f);
  24. // Define linear model
  25. Func<Tensor, Tensor> model = (x) => W * x + b;
  26. // Define the loss function
  27. Func<Tensor, Tensor, Tensor> loss = (target_y, predicted_y)
  28. => tf.reduce_mean(tf.square(target_y - predicted_y));
  29. int NUM_EXAMPLES = 1000;
  30. float TRUE_W = 3.0f;
  31. float TRUE_b = 2.0f;
  32. var inputs = tf.random.normal(shape: NUM_EXAMPLES);
  33. var noise = tf.random.normal(shape: NUM_EXAMPLES);
  34. var outputs = inputs * TRUE_W + TRUE_b + noise;
  35. print($"Current loss: {loss(model(inputs), outputs).numpy()}");
  36. // Define a training loop
  37. Action<Tensor, Tensor, float> train = (inputs, outputs, learning_rate)
  38. =>
  39. {
  40. using var t = tf.GradientTape();
  41. var current_loss = loss(outputs, model(inputs));
  42. var (dW, db) = t.gradient(current_loss, (W, b));
  43. W.assign_sub(learning_rate * dW);
  44. b.assign_sub(learning_rate * db);
  45. };
  46. var epochs = range(10);
  47. foreach(var epoch in epochs)
  48. {
  49. train(inputs, outputs, 0.1f);
  50. print($"Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f");
  51. }
  52. }
  53. }
  54. }