You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

BasicLinearModel.cs 2.5 kB

1 year ago
1 year ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using Tensorflow;
  4. using static Tensorflow.Binding;
  5. namespace TensorFlowNET.UnitTest.Training
  6. {
  7. [TestClass]
  8. public class BasicLinearModel
  9. {
  10. /// <summary>
  11. /// Linear Regression without tf.train.Optimizer
  12. /// https://www.tensorflow.org/tutorials/customization/custom_training
  13. /// </summary>
  14. [TestMethod]
  15. public void LinearRegression()
  16. {
  17. var graph = tf.Graph().as_default();
  18. var sess = new Session(graph);
  19. sess.as_default();
  20. // Initialize the weights to `5.0` and the bias to `0.0`
  21. // In practice, these should be initialized to random values (for example, with `tf.random.normal`)
  22. var W = tf.Variable(5.0f);
  23. var b = tf.Variable(0.0f);
  24. // Define linear model
  25. Func<Tensor, Tensor> model = (x) => W * x + b;
  26. // Define the loss function
  27. Func<Tensor, Tensor, Tensor> loss = (target_y, predicted_y)
  28. => tf.reduce_mean(tf.square(target_y - predicted_y));
  29. int NUM_EXAMPLES = 1000;
  30. float TRUE_W = 3.0f;
  31. float TRUE_b = 2.0f;
  32. var inputs = tf.random.normal(shape: NUM_EXAMPLES);
  33. var noise = tf.random.normal(shape: NUM_EXAMPLES);
  34. var outputs = inputs * TRUE_W + TRUE_b + noise;
  35. Tensor init_loss = loss(model(inputs), outputs);
  36. // print($"Current loss: {init_loss.numpy()}");
  37. // Define a training loop
  38. Func<Tensor, Tensor, float, Tensor> train = (inputs, outputs, learning_rate)
  39. =>
  40. {
  41. using var t = tf.GradientTape();
  42. var current_loss = loss(outputs, model(inputs));
  43. var (dW, db) = t.gradient(current_loss, (W, b));
  44. W.assign_sub(learning_rate * dW);
  45. b.assign_sub(learning_rate * db);
  46. return current_loss;
  47. };
  48. var epochs = range(10);
  49. foreach (var epoch in epochs)
  50. {
  51. var current_loss = train(inputs, outputs, 0.1f);
  52. print($"Epoch {epoch}: W={(float)W.numpy()} b={(float)b.numpy()}, loss={(float)current_loss.numpy()}");
  53. if (epoch > 0) // skip first epoch
  54. Assert.IsTrue((bool)(current_loss < init_loss));
  55. }
  56. }
  57. }
  58. }