You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

GradientTest.cs 2.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using Tensorflow;
  4. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.ManagedAPI
  {
      /// <summary>
      /// Eager-mode tests for <c>tf.GradientTape</c>: squares of scalar variables in several
      /// dtypes, elementwise multiplication, slicing, and concatenation.
      /// </summary>
      [TestClass]
      public class GradientTest
      {
          /// <summary>
          /// y = x² for a float32 variable at x = 3: checks y = 9 and dy/dx = 2x = 6.
          /// </summary>
          [TestMethod]
          public void GradientFloatTest()
          {
              var x = tf.Variable(3.0, dtype: tf.float32);
              using var tape = tf.GradientTape();
              var y = tf.square(x);
              var y_grad = tape.gradient(y, x);

              Assert.AreEqual(9.0f, (float)y);
              // d(x²)/dx = 2x, which is exactly representable (6) at x = 3.
              Assert.AreEqual(6.0f, (float)y_grad);
          }

          /// <summary>
          /// y = x² for a variable built from the double literal 3.0 with no explicit dtype;
          /// both y and its gradient are read back as doubles.
          /// </summary>
          [TestMethod]
          public void GradientDefaultTest()
          {
              var x = tf.Variable(3.0);
              using var tape = tf.GradientTape();
              var y = tf.square(x);
              var y_grad = tape.gradient(y, x);

              Assert.AreEqual(9.0, (double)y);
              Assert.AreEqual(6.0, (double)y_grad);
          }

          /// <summary>
          /// y = x² for a float64 variable at x = 3: checks y = 9 and dy/dx = 6.
          /// </summary>
          [TestMethod]
          public void GradientDoubleTest()
          {
              var x = tf.Variable(3.0, dtype: tf.float64);
              using var tape = tf.GradientTape();
              var y = tf.square(x);
              var y_grad = tape.gradient(y, x);

              Assert.AreEqual(9.0, (double)y);
              Assert.AreEqual(6.0, (double)y_grad);
          }

          /// <summary>
          /// y = x * w with scalar constant x = 0 and variable w = [1, 1]:
          /// dy/dw is x for each element, i.e. [0, 0].
          /// </summary>
          [TestMethod]
          public void GradientOperatorMulTest()
          {
              var x = tf.constant(0f);
              var w = tf.Variable(new float[] { 1, 1 });
              using var gt = tf.GradientTape();
              var y = x * w;
              var gr = gt.gradient(y, w);

              // NOTE(review): this compares a float[] against an NDArray via Assert.AreEqual, so it
              // depends on NDArray's equality overloads; consider CollectionAssert on an element
              // array for per-element failure messages.
              Assert.AreEqual(new float[] { 0, 0 }, gr.numpy());
          }

          /// <summary>
          /// Gradients flow back through <c>tf.slice</c> taken over the full extent of
          /// pred = W * X + b, where X is ten zeros.
          /// </summary>
          [TestMethod]
          public void GradientSliceTest()
          {
              var X = tf.zeros(10);
              var W = tf.Variable(-0.06f, name: "weight");
              var b = tf.Variable(-0.73f, name: "bias");
              using var g = tf.GradientTape();
              var pred = W * X + b;
              // Begin at 0 with size = pred's own shape: the slice keeps every element.
              var test = tf.slice(pred, new[] { 0 }, (int[])pred.shape);
              var gradients = g.gradient(test, (W, b));

              // The target is non-scalar, so its elements are summed: d/dW = sum(X) = 0 and
              // d/db = element count = 10. Arguments are (expected, actual) per MSTest's contract.
              Assert.AreEqual(0f, (float)gradients.Item1);
              Assert.AreEqual(10f, (float)gradients.Item2);
          }

          /// <summary>
          /// Gradient with respect to an intermediate tensor produced by <c>tf.concat</c>:
          /// w = concat([[1]], [[3]]) along axis 0 is 2x1, x = ones(1, 2), r = w · x is 2x2.
          /// </summary>
          [TestMethod]
          public void GradientConcatTest()
          {
              var w1 = tf.Variable(new[] { new[] { 1f } });
              var w2 = tf.Variable(new[] { new[] { 3f } });
              using var g = tf.GradientTape();
              var w = tf.concat(new Tensor[] { w1, w2 }, 0);
              var x = tf.ones((1, 2));
              var r = tf.matmul(w, x);
              var gradients = g.gradient(r, w);

              // Each entry of w multiplies the single row of x, whose entries sum to 2.
              Assert.AreEqual(2f, (float)gradients[0][0]);
              Assert.AreEqual(2f, (float)gradients[1][0]);
          }
      }
  }