
GradientEagerTest.cs 5.6 kB

using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using Tensorflow.NumPy;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Gradient
{
    [TestClass]
    public class GradientEagerTest : EagerModeTestBase
    {
        [TestMethod]
        public void ConstantSquare()
        {
            // Calculate the gradient of w * w
            // by automatic differentiation in eager mode.
            var w = tf.constant(1.5f);
            using var tape = tf.GradientTape();
            // w is defined before the tape starts recording.
            tape.watch(w);
            var loss = w * w;
            var grad = tape.gradient(loss, w);
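            // d(w^2)/dw = 2w, so the expected gradient at w = 1.5 is 3.0.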
            Assert.AreEqual((float)grad, 3.0f);
        }

        [TestMethod]
        public void SquaredDifference_Constant()
        {
            // Calculate the gradient of (x1 - x2)^2
            // by automatic differentiation in eager mode.
            var x1 = tf.constant(7f);
            var x2 = tf.constant(11f);

            // Sanity check
            using (var tape = tf.GradientTape())
            {
                tape.watch(x2);
                var loss = tf.multiply((x1 - x2), (x1 - x2));
                var result = tape.gradient(loss, x2);
                // d/dx2 (x1 - x2)^2 = 2 * (x2 - x1) = 2 * (11 - 7) = 8
                Assert.AreEqual((float)result, 8f);
            }

            // Actual test
            using (var tape = tf.GradientTape())
            {
                tape.watch(x2);
                var loss = tf.squared_difference(x1, x2);
                // Expected is 2 * (11 - 7) = 8
                var result = tape.gradient(loss, x2);
                Assert.AreEqual((float)result, 8f);
            }
        }

        [TestMethod]
        public void SquaredDifference_1D()
        {
            // Calculate the gradient of (x1 - x2)^2
            // by automatic differentiation in eager mode.
            // The expected gradient w.r.t. x2 is 2 * (x2 - x1), element-wise.
            Tensor x1 = new NDArray(new float[] { 1, 3, 5, 21, 19, 17 });
            Tensor x2 = new NDArray(new float[] { 29, 27, 23, 7, 11, 13 });
            float[] expected = new float[]
            {
                (29 - 1) * 2,
                (27 - 3) * 2,
                (23 - 5) * 2,
                (7 - 21) * 2,
                (11 - 19) * 2,
                (13 - 17) * 2
            };

            // Sanity check
            using (var tape = tf.GradientTape())
            {
                tape.watch(x1);
                tape.watch(x2);
                var loss = tf.multiply((x1 - x2), (x1 - x2));
                var result = tape.gradient(loss, x2);
                CollectionAssert.AreEqual(result.ToArray<float>(), expected);
            }

            // Actual test
            using (var tape = tf.GradientTape())
            {
                tape.watch(x1);
                tape.watch(x2);
                var loss = tf.squared_difference(x1, x2);
                var result = tape.gradient(loss, x2);
                CollectionAssert.AreEqual(result.ToArray<float>(), expected);
            }
        }

        /// <summary>
        /// Calculate the higher-order gradient of x * x * x
        /// by nesting two gradient tapes.
        /// </summary>
        [TestMethod]
        public void HighGradient()
        {
            var x = tf.Variable(1.0f);
            using var tape1 = tf.GradientTape();
            using var tape2 = tf.GradientTape();
            var y = x * x * x;
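            // dy/dx = 3x^2, which is 3.0 at x = 1.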
            var dy_dx = tape2.gradient(y, x);
            Assert.AreEqual((float)dy_dx, 3.0f);
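            // d2y/dx2 = 6x, which is 6.0 at x = 1.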
            var d2y_d2x = tape1.gradient(dy_dx, x);
            Assert.AreEqual((float)d2y_d2x, 6.0f);
        }

        [TestMethod]
        public void ConstantMultiply()
        {
            var x = tf.ones((2, 2));
            using var tape = tf.GradientTape();
            tape.watch(x);
            var y = tf.reduce_sum(x);
            var z = tf.multiply(y, y);
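            // y = sum(ones(2, 2)) = 4 and z = y^2, so dz/dx = 2y = 8 for every element of x.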
            var dz_dx = tape.gradient(z, x);
            var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f };
            Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected));
        }

        [TestMethod]
        public void PersistentTape()
        {
            var x = tf.ones((2, 2));
            using var tape = tf.GradientTape(persistent: true);
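            // A persistent tape keeps its recorded operations alive,
            // so gradient() can be called more than once below.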
            tape.watch(x);
            var y = tf.reduce_sum(x);
            var z = tf.multiply(y, y);
            var dz_dx = tape.gradient(z, x);
            var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f };
            Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected));
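            // Second gradient from the same tape: dz/dy = 2y = 8.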
            var dz_dy = tape.gradient(z, y);
            Assert.AreEqual((float)dz_dy, 8.0f);
        }

        [TestMethod]
        public void ConditionalMultiply()
        {
            Func<Tensor, int, Tensor> func = (x, y) =>
            {
                Tensor output = tf.constant(1.0f);
                foreach (var i in range(y))
                {
                    if (i > 1)
                        output = tf.multiply(output, x);
                }
                return output;
            };
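            // For y = 4 the loop multiplies by x at i = 2 and i = 3,
            // so func computes x^2 and its gradient is 2x (4.0 at x = 2).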
            Func<Tensor, int, Tensor> grad = (x, y) =>
            {
                using var tape = tf.GradientTape();
                tape.watch(x);
                var output = func(x, y);
                // Return the gradient directly; declaring a local named "grad"
                // here would shadow the enclosing delegate.
                return tape.gradient(output, x);
            };

            var x = tf.constant(2.0f);
            var result = grad(x, 4);
            Assert.AreEqual((float)result, 4.0f);
        }
    }
}