diff --git a/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs b/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs index 20361e8f..3166da0f 100644 --- a/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs +++ b/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs @@ -1,5 +1,6 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; +using System.Collections.Generic; using System.Linq; using Tensorflow; using Tensorflow.UnitTest; @@ -24,6 +25,81 @@ namespace TensorFlowNET.UnitTest.Gradient Assert.AreEqual((float)grad, 3.0f); } + [Ignore] + [TestMethod] + public void SquaredDifference_Constant() + { + // Calcute the gradient of (x1-x2)^2 + // by Automatic Differentiation in Eager mode + var x1 = tf.constant(7f); + var x2 = tf.constant(11f); + + // Sanity check + using (var tape = tf.GradientTape()) + { + tape.watch(x2); + var loss = tf.multiply((x1 - x2), (x1 - x2)); + + var result = tape.gradient(loss, x2); + // Expected is 2*(11-7) = 8 + Assert.AreEqual((float)result, 8f); + } + + // Actual test + using (var tape = tf.GradientTape()) + { + tape.watch(x2); + var loss = tf.squared_difference(x1, x2); + + // Expected is 2*(11-7) = 8 + var result = tape.gradient(loss, x2); + Assert.AreEqual((float)result, 8f); + } + } + + + [Ignore] + [TestMethod] + public void SquaredDifference_1D() + { + // Calcute the gradient of (x1-x2)^2 + // by Automatic Differentiation in Eager mode + // Expected is 2*(abs(x1-x2)) + Tensor x1 = new NumSharp.NDArray( new float[] { 1, 3, 5, 21, 19, 17 }); + Tensor x2 = new NumSharp.NDArray(new float[] { 29, 27, 23, 7, 11, 13 }); + float[] expected = new float[] { + (29-1) * 2, + (27-3) * 2, + (23-5) * 2, + (7-21) * 2, + (11-19) * 2, + (13-17) * 2 + }; + + // Sanity check + using (var tape = tf.GradientTape()) + { + tape.watch(x1); + tape.watch(x2); + var loss = tf.multiply((x1 - x2), (x1 - x2)); + + var result = tape.gradient(loss, x2); + CollectionAssert.AreEqual(result.ToArray(), expected); + } + + // Actual test + using (var tape = tf.GradientTape()) + { + tape.watch(x1); + tape.watch(x2); + var loss = tf.squared_difference(x1, x2); + + var result = tape.gradient(loss, x2); + CollectionAssert.AreEqual(result.ToArray(), expected); + } + } + + /// /// Calcute the gradient of w * w * w /// 高阶梯度