You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MeanSquaredError.Test.cs 2.1 kB

4 years ago
4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using Tensorflow;
  4. using Tensorflow.Keras.Losses;
  5. using static Tensorflow.Binding;
  6. using static Tensorflow.KerasApi;
  namespace TensorFlowNET.Keras.UnitTest
  {
      /// <summary>
      /// Unit tests for <c>keras.losses.MeanSquaredError</c>, mirroring the examples from
      /// https://keras.io/api/losses/regression_losses/#meansquarederror-class
      /// </summary>
      /// <remarks>
      /// Shared fixture: y_true = [[0, 1], [0, 0]] and y_pred = [[1, 1], [1, 0]].
      /// Each row has squared errors (1, 0), so the per-row loss is 0.5 for both rows
      /// (this is what <see cref="Mse_Reduction_NONE"/> asserts). The other expected values
      /// are reductions of that [0.5, 0.5] vector.
      /// NOTE(review): the assertions compare NDArrays for exact equality. The expected values
      /// are exact binary fractions, but a tolerance-based comparison would be more robust
      /// if the backend's reduction order ever changes. Consider switching once the NDArray
      /// scalar-conversion semantics are confirmed.
      /// </remarks>
      [TestClass]
      public class MeanSquaredErrorTest
      {
          // Ground-truth targets in double precision; one row per sample.
          private readonly NDArray _yTrue = new double[,] { { 0.0, 1.0 }, { 0.0, 0.0 } };

          // Predictions, paired row-by-row with _yTrue.
          private readonly NDArray _yPred = new double[,] { { 1.0, 1.0 }, { 1.0, 0.0 } };

          /// <summary>
          /// Default reduction on double inputs. The Keras docs example gives 0.5.
          /// The default reduction is presumably SUM_OVER_BATCH_SIZE, as in Keras; confirm
          /// against the TF.NET loss implementation.
          /// </summary>
          [TestMethod]
          public void Mse_Double()
          {
              var lossFn = keras.losses.MeanSquaredError();
              var expected = (NDArray)0.5;

              var actual = lossFn.Call(_yTrue, _yPred).numpy();

              Assert.AreEqual(expected, actual);
          }

          /// <summary>
          /// Same data as <see cref="Mse_Double"/>, but built from float32 arrays,
          /// to check that single-precision inputs produce the same 0.5 loss.
          /// </summary>
          [TestMethod]
          public void Mse_Float()
          {
              NDArray yTrue = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } };
              NDArray yPred = new float[,] { { 1.0f, 1.0f }, { 1.0f, 0.0f } };
              var lossFn = keras.losses.MeanSquaredError();
              var expected = (NDArray)0.5;

              var actual = lossFn.Call(yTrue, yPred).numpy();

              Assert.AreEqual(expected, actual);
          }

          /// <summary>
          /// Per-sample weights [0.7, 0.3] scale the row losses: 0.5 * 0.7 + 0.5 * 0.3 = 0.5.
          /// Averaged over the 2 rows, this gives the expected 0.25.
          /// </summary>
          [TestMethod]
          public void Mse_Sample_Weight()
          {
              var lossFn = keras.losses.MeanSquaredError();
              var weights = (NDArray)new double[] { 0.7, 0.3 };
              var expected = (NDArray)0.25;

              var actual = lossFn.Call(_yTrue, _yPred, sample_weight: weights).numpy();

              Assert.AreEqual(expected, actual);
          }

          /// <summary>
          /// SUM reduction adds the per-row losses: 0.5 + 0.5 = 1.0.
          /// </summary>
          [TestMethod]
          public void Mse_Reduction_SUM()
          {
              var lossFn = keras.losses.MeanSquaredError(reduction: Reduction.SUM);
              var expected = (NDArray)1.0;

              var actual = lossFn.Call(_yTrue, _yPred).numpy();

              Assert.AreEqual(expected, actual);
          }

          /// <summary>
          /// NONE reduction returns the unreduced per-row losses [0.5, 0.5].
          /// </summary>
          [TestMethod]
          public void Mse_Reduction_NONE()
          {
              var lossFn = keras.losses.MeanSquaredError(reduction: Reduction.NONE);
              var expected = (NDArray)new double[] { 0.5, 0.5 };

              var actual = lossFn.Call(_yTrue, _yPred).numpy();

              Assert.AreEqual(expected, actual);
          }
      }
  }