You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

BasicLinearModel.cs 2.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Text;
  6. using Tensorflow;
  7. using static Tensorflow.Binding;
  8. namespace TensorFlowNET.UnitTest.Training
  9. {
  10. [TestClass]
  11. public class BasicLinearModel
  12. {
  13. /// <summary>
  14. /// Linear Regression without tf.train.Optimizer
  15. /// https://www.tensorflow.org/tutorials/customization/custom_training
  16. /// </summary>
  17. [TestMethod]
  18. public void LinearRegression()
  19. {
  20. // Initialize the weights to `5.0` and the bias to `0.0`
  21. // In practice, these should be initialized to random values (for example, with `tf.random.normal`)
  22. var W = tf.Variable(5.0f);
  23. var b = tf.Variable(0.0f);
  24. // Define linear model
  25. Func<Tensor, Tensor> model = (x) => W * x + b;
  26. // Define the loss function
  27. Func<Tensor, Tensor, Tensor> loss = (target_y, predicted_y)
  28. => tf.reduce_mean(tf.square(target_y - predicted_y));
  29. int NUM_EXAMPLES = 1000;
  30. float TRUE_W = 3.0f;
  31. float TRUE_b = 2.0f;
  32. var inputs = tf.random.normal(shape: NUM_EXAMPLES);
  33. var noise = tf.random.normal(shape: NUM_EXAMPLES);
  34. var outputs = inputs * TRUE_W + TRUE_b + noise;
  35. Tensor init_loss = loss(model(inputs), outputs);
  36. // print($"Current loss: {init_loss.numpy()}");
  37. // Define a training loop
  38. Func<Tensor, Tensor, float, Tensor> train = (inputs, outputs, learning_rate)
  39. =>
  40. {
  41. using var t = tf.GradientTape();
  42. var current_loss = loss(outputs, model(inputs));
  43. var (dW, db) = t.gradient(current_loss, (W, b));
  44. W.assign_sub(learning_rate * dW);
  45. b.assign_sub(learning_rate * db);
  46. return current_loss;
  47. };
  48. var epochs = range(10);
  49. foreach(var epoch in epochs)
  50. {
  51. var current_loss = train(inputs, outputs, 0.1f);
  52. print($"Epoch {epoch}: W={(float)W.numpy()} b={(float)b.numpy()}, loss={(float)current_loss.numpy()}");
  53. if (epoch > 0) // skip first epoch
  54. Assert.IsTrue((bool)(current_loss < init_loss));
  55. }
  56. }
  57. }
  58. }