
BasicLinearModel.cs 2.4 kB

using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Training
{
    [TestClass]
    public class BasicLinearModel
    {
        /// <summary>
        /// Linear Regression without tf.train.Optimizer
        /// https://www.tensorflow.org/tutorials/customization/custom_training
        /// </summary>
        [TestMethod]
        public void LinearRegression()
        {
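            // Quantities exercised in this test:
            //     prediction:  y_hat = W * x + b
            //     loss:        L = mean((y - y_hat)^2)
            //     update:      W <- W - learning_rate * dL/dW,  b <- b - learning_rate * dL/db
            // i.e. plain gradient descent applied by hand, without an optimizer object
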
            // Initialize the weights to `5.0` and the bias to `0.0`
            // In practice, these should be initialized to random values (for example, with `tf.random.normal`)
            var W = tf.Variable(5.0f);
            var b = tf.Variable(0.0f);

            // Define linear model
            Func<Tensor, Tensor> model = (x) => W * x + b;

            // Define the loss function
            Func<Tensor, Tensor, Tensor> loss = (target_y, predicted_y)
                => tf.reduce_mean(tf.square(target_y - predicted_y));

            int NUM_EXAMPLES = 1000;
            float TRUE_W = 3.0f;
            float TRUE_b = 2.0f;
            var inputs = tf.random.normal(shape: NUM_EXAMPLES);
            var noise = tf.random.normal(shape: NUM_EXAMPLES);
            var outputs = inputs * TRUE_W + TRUE_b + noise;
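            // The synthetic data follows y = TRUE_W * x + TRUE_b plus Gaussian noise,
            // so a successful fit should drive W toward 3.0 and b toward 2.0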

            // Loss before any training; the loop below asserts improvement against this baseline
            Tensor init_loss = loss(outputs, model(inputs));
            // print($"Current loss: {init_loss.numpy()}");

            // Define a single training step (the loop below calls it once per epoch)
            Func<Tensor, Tensor, float, Tensor> train = (x, y, learning_rate) =>
            {
                using var t = tf.GradientTape();
                var current_loss = loss(y, model(x));
                var (dW, db) = t.gradient(current_loss, (W, b));
                W.assign_sub(learning_rate * dW);
                b.assign_sub(learning_rate * db);
                return current_loss;
            };
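            // For this mean-squared-error loss, the gradients the tape returns correspond to the
            // closed forms
            //     dL/dW = -2 * mean(x * (y - (W * x + b)))
            //     dL/db = -2 * mean(y - (W * x + b))
            // so a sufficiently small learning rate decreases the loss at every step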

            var epochs = range(10);
            foreach (var epoch in epochs)
            {
                var current_loss = train(inputs, outputs, 0.1f);
                print($"Epoch {epoch}: W={(float)W.numpy()} b={(float)b.numpy()}, loss={(float)current_loss.numpy()}");

                // Skip the first epoch: its loss is measured before any update is applied,
                // so it equals init_loss rather than improving on it
                if (epoch > 0)
                    Assert.IsTrue((bool)(current_loss < init_loss));
            }
        }
    }
}
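
For reference, the same arithmetic can be followed without any TensorFlow.NET calls. The sketch below is only an illustration of the gradient-descent loop that LinearRegression exercises, using plain float arrays in place of tensors; the class name, the Gaussian helper, and the fixed random seed are invented for this example and are not part of the test project.

using System;
using System.Linq;

class LinearRegressionSketch
{
    static void Main()
    {
        var rng = new Random(0);
        const int numExamples = 1000;
        const float trueW = 3.0f, trueB = 2.0f, learningRate = 0.1f;

        // Synthetic data mirroring the test: y = 3x + 2 plus Gaussian noise
        var x = Enumerable.Range(0, numExamples).Select(_ => Gaussian(rng)).ToArray();
        var y = x.Select(xi => trueW * xi + trueB + Gaussian(rng)).ToArray();

        // Same starting point as the test: W = 5, b = 0
        float W = 5.0f, b = 0.0f;
        for (int epoch = 0; epoch < 10; epoch++)
        {
            // Mean squared error and its gradients, written out by hand
            float loss = 0f, dW = 0f, db = 0f;
            for (int i = 0; i < numExamples; i++)
            {
                float err = y[i] - (W * x[i] + b);
                loss += err * err;
                dW += -2f * x[i] * err;
                db += -2f * err;
            }
            loss /= numExamples; dW /= numExamples; db /= numExamples;

            // The update the test performs with assign_sub
            W -= learningRate * dW;
            b -= learningRate * db;
            Console.WriteLine($"Epoch {epoch}: W={W:F3} b={b:F3}, loss={loss:F3}");
        }
    }

    // Box-Muller transform for a standard normal sample
    static float Gaussian(Random rng) =>
        (float)(Math.Sqrt(-2.0 * Math.Log(1.0 - rng.NextDouble())) *
                Math.Cos(2.0 * Math.PI * rng.NextDouble()));
}

As in the test, the loss is computed before the parameters are updated, so the value printed for epoch 0 matches the untrained baseline and only the later epochs are expected to improve on it.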