GradientEagerTest.cs

using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using Tensorflow.UnitTest;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Gradient
{
    [TestClass]
    public class GradientEagerTest : EagerModeTestBase
    {
        [TestMethod]
        public void ConstantSquare()
        {
            // Calculate the gradient of w * w
            // by automatic differentiation in eager mode
            // (TensorFlow.NET 2.x).
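            // Analytically d(w^2)/dw = 2w, so at w = 1.5 the gradient is 3.0.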
            var w = tf.constant(1.5f);
            using var tape = tf.GradientTape();
            tape.watch(w);
            var loss = w * w;
            var grad = tape.gradient(loss, w);
            Assert.AreEqual((float)grad, 3.0f);
        }

        [TestMethod]
        public void SquaredDifference_Constant()
        {
            // Calculate the gradient of (x1 - x2)^2
            // by automatic differentiation in eager mode.
            var x1 = tf.constant(7f);
            var x2 = tf.constant(11f);

            // Sanity check
            using (var tape = tf.GradientTape())
            {
                tape.watch(x2);
                var loss = tf.multiply((x1 - x2), (x1 - x2));
                var result = tape.gradient(loss, x2);
                // Expected is 2 * (11 - 7) = 8
                Assert.AreEqual((float)result, 8f);
            }

            // Actual test
            using (var tape = tf.GradientTape())
            {
                tape.watch(x2);
                var loss = tf.squared_difference(x1, x2);
                // Expected is 2 * (11 - 7) = 8
                var result = tape.gradient(loss, x2);
                Assert.AreEqual((float)result, 8f);
            }
        }

        [Ignore]
        [TestMethod]
        public void SquaredDifference_1D()
        {
            // Calculate the gradient of (x1 - x2)^2
            // by automatic differentiation in eager mode.
            // Expected is 2 * (x2 - x1), element-wise (note the signs below).
            Tensor x1 = new NumSharp.NDArray(new float[] { 1, 3, 5, 21, 19, 17 });
            Tensor x2 = new NumSharp.NDArray(new float[] { 29, 27, 23, 7, 11, 13 });
            float[] expected = new float[]
            {
                (29 - 1) * 2,
                (27 - 3) * 2,
                (23 - 5) * 2,
                (7 - 21) * 2,
                (11 - 19) * 2,
                (13 - 17) * 2
            };

            // Sanity check
            using (var tape = tf.GradientTape())
            {
                tape.watch(x1);
                tape.watch(x2);
                var loss = tf.multiply((x1 - x2), (x1 - x2));
                var result = tape.gradient(loss, x2);
                CollectionAssert.AreEqual(result.ToArray<float>(), expected);
            }

            // Actual test
            using (var tape = tf.GradientTape())
            {
                tape.watch(x1);
                tape.watch(x2);
                var loss = tf.squared_difference(x1, x2);
                var result = tape.gradient(loss, x2);
                CollectionAssert.AreEqual(result.ToArray<float>(), expected);
            }
        }

        /// <summary>
        /// Calculate the gradient of x * x * x
        /// (higher-order gradients via nested tapes).
        /// </summary>
        [TestMethod]
        public void HighGradient()
        {
            var x = tf.Variable(1.0f);
            using var tape1 = tf.GradientTape();
            using var tape2 = tf.GradientTape();
            var y = x * x * x;
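            // Dispose the inner tape before querying it for the first-order
            // gradient: dy/dx = 3x^2 = 3 at x = 1.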
            tape2.Dispose();
            var dy_dx = tape2.gradient(y, x);
            Assert.AreEqual((float)dy_dx, 3.0f);
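            // The outer tape recorded the first gradient computation, so it
            // can differentiate again: d2y/dx2 = 6x = 6 at x = 1.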
            tape1.Dispose();
            var d2y_d2x = tape1.gradient(dy_dx, x);
            Assert.AreEqual((float)d2y_d2x, 6.0f);
        }

        [TestMethod]
        public void ConstantMultiply()
        {
            var x = tf.ones((2, 2));
            using var tape = tf.GradientTape();
            tape.watch(x);
            var y = tf.reduce_sum(x);
            var z = tf.multiply(y, y);
            var dz_dx = tape.gradient(z, x);
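            // y = sum(ones(2, 2)) = 4 and z = y^2, so dz/dx = 2y = 8 for each element.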
            var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f };
            Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected));
        }

        [TestMethod]
        public void PersistentTape()
        {
            var x = tf.ones((2, 2));
            using var tape = tf.GradientTape(persistent: true);
            tape.watch(x);
            var y = tf.reduce_sum(x);
            var z = tf.multiply(y, y);
            tape.Dispose();
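            // Because the tape is persistent, gradient() can still be called
            // (and more than once) after recording has stopped.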
            var dz_dx = tape.gradient(z, x);
            var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f };
            Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected));

            var dz_dy = tape.gradient(z, y);
            Assert.AreEqual((float)dz_dy, 8.0f);
        }

        [TestMethod]
        public void ConditionalMultiply()
        {
            Func<Tensor, int, Tensor> func = (x, y) =>
            {
                Tensor output = tf.constant(1.0f);
                foreach (var i in range(y))
                {
                    if (i > 1)
                        output = tf.multiply(output, x);
                }
                return output;
            };

            Func<Tensor, int, Tensor> grad = (x, y) =>
            {
                using var tape = tf.GradientTape();
                tape.watch(x);
                var output = func(x, y);
                // Local renamed from "grad" to avoid shadowing the enclosing
                // "grad" delegate (compile error CS0136).
                var g = tape.gradient(output, x);
                return g;
            };
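            // With y = 4 the loop multiplies only for i = 2 and i = 3,
            // so output = x^2 and the gradient is 2x = 4 at x = 2.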
            var x = tf.constant(2.0f);
            var result = grad(x, 4);
            Assert.AreEqual((float)result, 4.0f);
        }
    }
}