From 788cd5c20ce1be4064e205425f07bb3e655d6bb7 Mon Sep 17 00:00:00 2001 From: Will Date: Sat, 14 Dec 2019 19:37:37 -0800 Subject: [PATCH] Fix Float64 equal compare in GradientTest --- test/TensorFlowNET.UnitTest/PythonTest.cs | 21 +++++++++++++++++++ .../gradients_test/GradientsTest.cs | 14 ++++--------- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs index d2ae36d7..5ceeb9b5 100644 --- a/test/TensorFlowNET.UnitTest/PythonTest.cs +++ b/test/TensorFlowNET.UnitTest/PythonTest.cs @@ -52,6 +52,17 @@ namespace TensorFlowNET.UnitTest assertItemsEqual(given, expected); } + public void assertFloat32Equal(float expected, float actual, string msg) + { + float eps = 1e-6f; + Assert.IsTrue(Math.Abs(expected - actual) < eps * Math.Max(1.0f, Math.Abs(expected)), $"{msg}: expected {expected} vs actual {actual}"); + } + + public void assertFloat64Equal(double expected, double actual, string msg) + { + double eps = 1e-16f; + Assert.IsTrue(Math.Abs(expected - actual) < eps * Math.Max(1.0f, Math.Abs(expected)), $"{msg}: expected {expected} vs actual {actual}"); + } public void assertEqual(object given, object expected) { @@ -70,6 +81,16 @@ namespace TensorFlowNET.UnitTest assertItemsEqual(given as ICollection, expected as ICollection); return; } + if (given is float && expected is float) + { + assertFloat32Equal((float)expected, (float)given, ""); + return; + } + if (given is double && expected is double) + { + assertFloat64Equal((double)expected, (double)given, ""); + return; + } Assert.AreEqual(expected, given); } diff --git a/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs b/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs index 27acbad0..c7a26cdd 100644 --- a/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs +++ b/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs @@ -110,12 +110,6 @@ namespace TensorFlowNET.UnitTest.gradients_test } } - void assertFloat32Equal(float expected, float actual, string msg) - { - float eps = 1e-6f; - Assert.IsTrue(Math.Abs(expected - actual) < eps * Math.Max(1.0f, Math.Abs(expected)), $"{msg}: expected {expected} vs actual {actual}"); - } - void test(string name, Func tfF, Func targetF, double[] values) { foreach (var x in values) @@ -124,14 +118,14 @@ namespace TensorFlowNET.UnitTest.gradients_test { var (actualY, actualDY) = evaluateDerivatives(tfF, x); - Assert.AreEqual(expectedY, actualY, $"value {name}/float64 at {x}"); - Assert.AreEqual(expectedDY, actualDY, $"derivative {name}/float64 at {x}"); + self.assertFloat64Equal(expectedY, actualY, $"value {name}/float64 at {x}"); + self.assertFloat64Equal(expectedDY, actualDY, $"derivative {name}/float64 at {x}"); } { var (actualY, actualDY) = evaluateDerivatives(tfF, (float)x); - assertFloat32Equal((float)expectedY, actualY, $"value {name}/float32 at {x}"); - assertFloat32Equal((float)expectedDY, actualDY, $"derivative {name}/float32 at {x}"); + self.assertFloat32Equal((float)expectedY, actualY, $"value {name}/float32 at {x}"); + self.assertFloat32Equal((float)expectedDY, actualDY, $"derivative {name}/float32 at {x}"); } } }