You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

BasicLinearModel.cs 2.4 kB

1 year ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using Tensorflow;
  4. using static Tensorflow.Binding;
  5. namespace TensorFlowNET.UnitTest.Training
  6. {
  7. [TestClass]
  8. public class BasicLinearModel
  9. {
  10. /// <summary>
  11. /// Linear Regression without tf.train.Optimizer
  12. /// https://www.tensorflow.org/tutorials/customization/custom_training
  13. /// </summary>
  14. [TestMethod]
  15. public void LinearRegression()
  16. {
  17. tf.Graph().as_default();
  18. // Initialize the weights to `5.0` and the bias to `0.0`
  19. // In practice, these should be initialized to random values (for example, with `tf.random.normal`)
  20. var W = tf.Variable(5.0f);
  21. var b = tf.Variable(0.0f);
  22. // Define linear model
  23. Func<Tensor, Tensor> model = (x) => W * x + b;
  24. // Define the loss function
  25. Func<Tensor, Tensor, Tensor> loss = (target_y, predicted_y)
  26. => tf.reduce_mean(tf.square(target_y - predicted_y));
  27. int NUM_EXAMPLES = 1000;
  28. float TRUE_W = 3.0f;
  29. float TRUE_b = 2.0f;
  30. var inputs = tf.random.normal(shape: NUM_EXAMPLES);
  31. var noise = tf.random.normal(shape: NUM_EXAMPLES);
  32. var outputs = inputs * TRUE_W + TRUE_b + noise;
  33. Tensor init_loss = loss(model(inputs), outputs);
  34. // print($"Current loss: {init_loss.numpy()}");
  35. // Define a training loop
  36. Func<Tensor, Tensor, float, Tensor> train = (inputs, outputs, learning_rate)
  37. =>
  38. {
  39. using var t = tf.GradientTape();
  40. var current_loss = loss(outputs, model(inputs));
  41. var (dW, db) = t.gradient(current_loss, (W, b));
  42. W.assign_sub(learning_rate * dW);
  43. b.assign_sub(learning_rate * db);
  44. return current_loss;
  45. };
  46. var epochs = range(10);
  47. foreach (var epoch in epochs)
  48. {
  49. var current_loss = train(inputs, outputs, 0.1f);
  50. print($"Epoch {epoch}: W={(float)W.numpy()} b={(float)b.numpy()}, loss={(float)current_loss.numpy()}");
  51. if (epoch > 0) // skip first epoch
  52. Assert.IsTrue((bool)(current_loss < init_loss));
  53. }
  54. }
  55. }
  56. }