From 421c663bdeb4272a1d3ddebcc182747daa922235 Mon Sep 17 00:00:00 2001 From: Alexander Mishunin Date: Sat, 14 Dec 2019 02:22:20 +0300 Subject: [PATCH] A few simple gradient tests --- .../gradients_test/GradientsTest.cs | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs b/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs index 2fae1e5b..27acbad0 100644 --- a/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs +++ b/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs @@ -94,6 +94,99 @@ namespace TensorFlowNET.UnitTest.gradients_test } } + [TestMethod] + public void testSimpleGradients() + { + (T, T) evaluateDerivatives(Func f, T xval) where T : unmanaged + { + var x = tf.constant(xval); + var y = f(x); + var g = tf.gradients(y, x); + + using (var session = tf.Session()) + { + var result = session.run(new[] { y, g[0] }); + return (result[0].GetData()[0], result[1].GetData()[0]); + } + } + + void assertFloat32Equal(float expected, float actual, string msg) + { + float eps = 1e-6f; + Assert.IsTrue(Math.Abs(expected - actual) < eps * Math.Max(1.0f, Math.Abs(expected)), $"{msg}: expected {expected} vs actual {actual}"); + } + + void test(string name, Func tfF, Func targetF, double[] values) + { + foreach (var x in values) + { + var (expectedY, expectedDY) = targetF(x); + + { + var (actualY, actualDY) = evaluateDerivatives(tfF, x); + Assert.AreEqual(expectedY, actualY, $"value {name}/float64 at {x}"); + Assert.AreEqual(expectedDY, actualDY, $"derivative {name}/float64 at {x}"); + } + + { + var (actualY, actualDY) = evaluateDerivatives(tfF, (float)x); + assertFloat32Equal((float)expectedY, actualY, $"value {name}/float32 at {x}"); + assertFloat32Equal((float)expectedDY, actualDY, $"derivative {name}/float32 at {x}"); + } + } + } + + test("tf.exp", + x => tf.exp(5 * x), + x => (Math.Exp(5.0 * x), 5.0 * Math.Exp(5.0 * x)), + new[] { -1.0, 0.0, 1.0, 1.5 }); + + test("tf.log", + x => tf.log(x), + x => (Math.Log(x), 1.0 / x), + new[] { 0.5, 1.0, 1.5, 2.0 }); + + test("tf.sqrt", + x => tf.sqrt(x), + x => (Math.Sqrt(x), 0.5 / Math.Sqrt(x)), + new[] { 0.5, 1.0, 1.1, 1.5, 2.0 }); + + test("tf.sin", + x => tf.sin(x), + x => (Math.Sin(x), Math.Cos(x)), + new[] { -1.0, 0.0, 1.0, 1.5, 2.0 }); + + test("tf.sinh", + x => tf.sinh(x), + x => (Math.Sinh(x), Math.Cosh(x)), + new[] { -1.0, 0.0, 1.0, 1.5, 2.0 }); + + test("tf.cos", + x => tf.cos(x), + x => (Math.Cos(x), -Math.Sin(x)), + new[] { -1.0, 0.0, 1.0, 1.5, 2.0 }); + + test("tf.cosh", + x => tf.cosh(x), + x => (Math.Cosh(x), Math.Sinh(x)), + new[] { -1.0, 0.0, 1.0, 1.5, 2.0 }); + + test("tf.tanh", + x => tf.tanh(x), + x => (Math.Tanh(x), 1.0 - Math.Pow(Math.Tanh(x), 2.0)), + new[] { -1.0, 0.0, 1.0, 1.5, 2.0 }); + + test("tf.maximum", + x => tf.maximum(x, tf.constant(0.0, dtype: x.dtype)), + x => (Math.Max(x, 0.0), (x > 0.0) ? 1.0 : 0.0), + new[] { -1.0, 1.0 }); + + test("tf.minimum", + x => tf.minimum(x, tf.constant(0.0, dtype: x.dtype)), + x => (Math.Min(x, 0.0), (x < 0.0) ? 1.0 : 0.0), + new[] { -1.0, 1.0 }); + } + [TestMethod] public void testTanhGradient() {