
GradientEagerTest.cs

using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using Tensorflow.NumPy;
using Tensorflow.UnitTest;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Gradient
{
    [TestClass]
    public class GradientEagerTest : EagerModeTestBase
    {
        [TestMethod]
        public void ConstantSquare()
        {
            // Calculate the gradient of w * w by automatic
            // differentiation in eager mode (TensorFlow.NET 2.x).
            var w = tf.constant(1.5f);
            using var tape = tf.GradientTape();
            tape.watch(w);
            var loss = w * w;
            var grad = tape.gradient(loss, w);
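            // d(w * w)/dw = 2 * w = 3.0 at w = 1.5.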
            Assert.AreEqual((float)grad, 3.0f);
        }

        [TestMethod]
        public void SquaredDifference_Constant()
        {
            // Calculate the gradient of (x1 - x2)^2
            // by automatic differentiation in eager mode.
            var x1 = tf.constant(7f);
            var x2 = tf.constant(11f);

            // Sanity check
            using (var tape = tf.GradientTape())
            {
                tape.watch(x2);
                var loss = tf.multiply((x1 - x2), (x1 - x2));
                var result = tape.gradient(loss, x2);
                // d/dx2 (x1 - x2)^2 = 2 * (x2 - x1) = 2 * (11 - 7) = 8
                Assert.AreEqual((float)result, 8f);
            }

            // Actual test
            using (var tape = tf.GradientTape())
            {
                tape.watch(x2);
                var loss = tf.squared_difference(x1, x2);
                // Expected is 2 * (11 - 7) = 8, as above
                var result = tape.gradient(loss, x2);
                Assert.AreEqual((float)result, 8f);
            }
        }

        [Ignore]
        [TestMethod]
        public void SquaredDifference_1D()
        {
            // Calculate the gradient of (x1 - x2)^2
            // by automatic differentiation in eager mode.
            // The expected gradient w.r.t. x2 is 2 * (x2 - x1).
            Tensor x1 = new NDArray(new float[] { 1, 3, 5, 21, 19, 17 });
            Tensor x2 = new NDArray(new float[] { 29, 27, 23, 7, 11, 13 });
            float[] expected = new float[]
            {
                (29 - 1) * 2,
                (27 - 3) * 2,
                (23 - 5) * 2,
                (7 - 21) * 2,
                (11 - 19) * 2,
                (13 - 17) * 2
            };

            // Sanity check
            using (var tape = tf.GradientTape())
            {
                tape.watch(x1);
                tape.watch(x2);
                var loss = tf.multiply((x1 - x2), (x1 - x2));
                var result = tape.gradient(loss, x2);
                CollectionAssert.AreEqual(result.ToArray<float>(), expected);
            }

            // Actual test
            using (var tape = tf.GradientTape())
            {
                tape.watch(x1);
                tape.watch(x2);
                var loss = tf.squared_difference(x1, x2);
                var result = tape.gradient(loss, x2);
                CollectionAssert.AreEqual(result.ToArray<float>(), expected);
            }
        }

        /// <summary>
        /// Calculate the gradient of w * w * w
        /// (higher-order gradients).
        /// </summary>
        [TestMethod]
        public void HighGradient()
        {
            var x = tf.Variable(1.0f);
            using var tape1 = tf.GradientTape();
            using var tape2 = tf.GradientTape();
            var y = x * x * x;
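            // Disposing the inner tape ends its recording, so the first-order
            // gradient computed below is itself recorded by the outer tape.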
            tape2.Dispose();
            var dy_dx = tape2.gradient(y, x);
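            // dy/dx = 3 * x^2 = 3.0 at x = 1.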
            Assert.AreEqual((float)dy_dx, 3.0f);
            tape1.Dispose();
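            // d2y/dx2 = 6 * x = 6.0 at x = 1.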
            var d2y_d2x = tape1.gradient(dy_dx, x);
            Assert.AreEqual((float)d2y_d2x, 6.0f);
        }

        [TestMethod]
        public void ConstantMultiply()
        {
            var x = tf.ones((2, 2));
            using var tape = tf.GradientTape();
            tape.watch(x);
            var y = tf.reduce_sum(x);
            var z = tf.multiply(y, y);
            var dz_dx = tape.gradient(z, x);
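            // z = y^2 with y = sum(x) = 4, so dz/dx = 2 * y = 8 for every element.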
            var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f };
            Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected));
        }

        [TestMethod]
        public void PersistentTape()
        {
            var x = tf.ones((2, 2));
            using var tape = tf.GradientTape(persistent: true);
            tape.watch(x);
            var y = tf.reduce_sum(x);
            var z = tf.multiply(y, y);
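            // Disposing ends recording; a persistent tape can still compute
            // multiple gradients afterwards.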
            tape.Dispose();
            var dz_dx = tape.gradient(z, x);
            var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f };
            Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected));
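            // The same tape can be queried again: dz/dy = 2 * y = 8.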
            var dz_dy = tape.gradient(z, y);
            Assert.AreEqual((float)dz_dy, 8.0f);
        }

        [TestMethod]
        public void ConditionalMultiply()
        {
            Func<Tensor, int, Tensor> func = (x, y) =>
            {
                Tensor output = tf.constant(1.0f);
                foreach (var i in range(y))
                {
                    if (i > 1)
                        output = tf.multiply(output, x);
                }
                return output;
            };

            Func<Tensor, int, Tensor> grad = (x, y) =>
            {
                using var tape = tf.GradientTape();
                tape.watch(x);
                var output = func(x, y);
                var g = tape.gradient(output, x);
                return g;
            };

            var input = tf.constant(2.0f);
            var result = grad(input, 4);
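            // For y = 4 the loop multiplies at i = 2 and i = 3, so
            // output = x^2 and d(x^2)/dx = 2 * x = 4.0 at x = 2.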
            Assert.AreEqual((float)result, 4.0f);
        }
    }
}