From f8daeae69698806a55922a974da7adc89f03a34b Mon Sep 17 00:00:00 2001 From: dataangel Date: Tue, 15 Dec 2020 00:16:12 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0Mes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Keras/MeanSquaredError.Test.cs | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 test/TensorFlowNET.UnitTest/Keras/MeanSquaredError.Test.cs diff --git a/test/TensorFlowNET.UnitTest/Keras/MeanSquaredError.Test.cs b/test/TensorFlowNET.UnitTest/Keras/MeanSquaredError.Test.cs new file mode 100644 index 00000000..f1c782f8 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Keras/MeanSquaredError.Test.cs @@ -0,0 +1,65 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using NumSharp; +using Tensorflow; +using Tensorflow.Keras.Losses; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace TensorFlowNET.UnitTest.Keras +{ + [TestClass] + public class MeanSquaredErrorTest + { + //https://keras.io/api/losses/regression_losses/#meansquarederror-class + + private NDArray y_true = new double[,] { { 0.0, 1.0 }, { 0.0, 0.0 } }; + private NDArray y_pred = new double[,] { { 1.0, 1.0 }, { 1.0, 0.0 } }; + + [TestMethod] + + public void Mse_Double() + { + var mse = keras.losses.MeanSquaredError(); + var call = mse.Call(y_true, y_pred); + Assert.AreEqual((NDArray)0.5, call.numpy()) ; + } + + [TestMethod] + + public void Mse_Float() + { + NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } }; + NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 1.0f, 0.0f } }; + + var mse = keras.losses.MeanSquaredError(); + var call = mse.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)0.5, call.numpy()); + } + + [TestMethod] + + public void Mse_Sample_Weight() + { + var mse = keras.losses.MeanSquaredError(); + var call = mse.Call(y_true, y_pred, sample_weight: (NDArray)new double[] { 0.7, 0.3 }); + Assert.AreEqual((NDArray)0.25, call.numpy()); + } + + [TestMethod] + public void Mse_Reduction_SUM() + { + var mse = keras.losses.MeanSquaredError(reduction: Reduction.SUM); + var call = mse.Call(y_true, y_pred); + Assert.AreEqual((NDArray)1.0, call.numpy()); + } + + [TestMethod] + + public void Mse_Reduction_NONE() + { + var mse = keras.losses.MeanSquaredError(reduction: Reduction.NONE); + var call = mse.Call(y_true, y_pred); + Assert.AreEqual((NDArray)new double[] { 0.5, 0.5 }, call.numpy()); + } + } +}