diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs index 98106631..6e2aa282 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs @@ -1,4 +1,5 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; +using NumSharp; using Tensorflow; using static Tensorflow.Binding; @@ -65,15 +66,16 @@ namespace TensorFlowNET.UnitTest.ManagedAPI [TestMethod] public void GradientConcatTest() { - var X = tf.zeros(10); - var W = tf.Variable(-0.06f, name: "weight"); - var b = tf.Variable(-0.73f, name: "bias"); - var test = tf.concat(new Tensor[] { W, b }, 0); + var w1 = tf.Variable(new[] { new[] { 1f } }); + var w2 = tf.Variable(new[] { new[] { 3f } }); using var g = tf.GradientTape(); - var pred = test[0] * X + test[1]; - var gradients = g.gradient(pred, (W, b)); - Assert.IsNull(gradients.Item1); - Assert.IsNull(gradients.Item2); + var w = tf.concat(new Tensor[] { w1, w2 }, 0); + var x = tf.ones((1, 2)); + var y = tf.reduce_sum(x, 1); + var r = tf.matmul(w, x); + var gradients = g.gradient(r, w); + Assert.AreEqual((float)gradients[0][0], 2f); + Assert.AreEqual((float)gradients[1][0], 2f); } } }