You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MeanSquaredError.Test.cs 1.9 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using static Tensorflow.KerasApi;
  namespace Tensorflow.Keras.UnitTest.Layers
  {
      /// <summary>
      /// Unit tests for <c>keras.losses.MeanSquaredError</c>, checked against the reference
      /// example at https://keras.io/api/losses/regression_losses/#meansquarederror-class
      /// </summary>
      [TestClass]
      public class MeanSquaredErrorTest
      {
          // Shared 2x2 double-precision fixtures. Row-wise squared errors are {1, 0} and {1, 0},
          // so each row's mean squared error is 0.5 (see Mse_Reduction_NONE).
          private readonly NDArray _yTrue = new double[,] { { 0.0, 1.0 }, { 0.0, 0.0 } };
          private readonly NDArray _yPred = new double[,] { { 1.0, 1.0 }, { 1.0, 0.0 } };

          // NOTE(review): every assertion below passes the computed NDArray FIRST, i.e. in the
          // slot MSTest names "expected". This is presumably deliberate so NDArray's own Equals
          // override performs the comparison against the scalar/array; confirm before swapping
          // the arguments into MSTest's (expected, actual) order.

          /// <summary>Default reduction on double inputs yields the mean of the row losses, 0.5.</summary>
          [TestMethod]
          public void Mse_Double()
          {
              var lossFn = keras.losses.MeanSquaredError();

              var actual = lossFn.Call(_yTrue, _yPred).numpy();

              Assert.AreEqual(actual, 0.5);
          }

          /// <summary>Same data as <see cref="Mse_Double"/> but in single precision; expects 0.5f.</summary>
          [TestMethod]
          public void Mse_Float()
          {
              NDArray truthF32 = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } };
              NDArray predictionF32 = new float[,] { { 1.0f, 1.0f }, { 1.0f, 0.0f } };
              var lossFn = keras.losses.MeanSquaredError();

              var actual = lossFn.Call(truthF32, predictionF32).numpy();

              Assert.AreEqual(actual, 0.5f);
          }

          /// <summary>
          /// Per-row weights {0.7, 0.3} applied to row losses {0.5, 0.5}; the test expects 0.25
          /// (presumably the weighted sum 0.5 divided by the batch size of 2 — confirm).
          /// </summary>
          [TestMethod]
          public void Mse_Sample_Weight()
          {
              var lossFn = keras.losses.MeanSquaredError();
              var rowWeights = (NDArray)new double[] { 0.7, 0.3 };

              var actual = lossFn.Call(_yTrue, _yPred, sample_weight: rowWeights).numpy();

              Assert.AreEqual(actual, 0.25);
          }

          /// <summary>SUM reduction adds the row losses: 0.5 + 0.5 = 1.0.</summary>
          [TestMethod]
          public void Mse_Reduction_SUM()
          {
              var lossFn = keras.losses.MeanSquaredError(reduction: Reduction.SUM);

              var actual = lossFn.Call(_yTrue, _yPred).numpy();

              Assert.AreEqual(actual, 1.0);
          }

          /// <summary>NONE reduction returns the unreduced per-row losses {0.5, 0.5}.</summary>
          [TestMethod]
          public void Mse_Reduction_NONE()
          {
              var lossFn = keras.losses.MeanSquaredError(reduction: Reduction.NONE);
              var perRowExpected = new double[] { 0.5, 0.5 };

              var actual = lossFn.Call(_yTrue, _yPred).numpy();

              Assert.AreEqual(actual, perRowExpected);
          }
      }
  }