You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LossesTest.cs 2.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using static Tensorflow.Binding;
  4. namespace Tensorflow.Keras.UnitTest.Losses;
  [TestClass]
  public class LossesTest : EagerModeTestBase
  {
      // Tolerances for float32 loss comparisons. The expected values below are copied
      // from rounded reference outputs, so exact float equality would be brittle.
      private const float RelativeTolerance = 1e-5f;
      private const float AbsoluteTolerance = 1e-7f;

      /// <summary>
      /// Asserts that <paramref name="actual"/> is within tolerance of <paramref name="expected"/>.
      /// The allowed delta is the larger of <see cref="AbsoluteTolerance"/> and
      /// <see cref="RelativeTolerance"/> scaled by |expected|.
      /// </summary>
      /// <param name="expected">Reference value.</param>
      /// <param name="actual">Value produced by the code under test.</param>
      /// <param name="message">Optional context included in the failure message.</param>
      private static void AssertClose(float expected, float actual, string message = "")
      {
          float delta = System.Math.Max(AbsoluteTolerance, RelativeTolerance * System.Math.Abs(expected));
          // MSTest's signature is (expected, actual, delta, message); keep that order so
          // failure messages report the values the right way round.
          Assert.AreEqual(expected, actual, delta, message);
      }

      /// <summary>
      /// Asserts that <paramref name="actual"/> contains exactly as many elements as
      /// <paramref name="expected"/>, and that each element (in the order returned by
      /// <c>ToArray&lt;float&gt;()</c>) is within tolerance of the corresponding expected value.
      /// </summary>
      /// <param name="expected">Reference values.</param>
      /// <param name="actual">Array produced by the code under test.</param>
      private static void AssertAllClose(float[] expected, NDArray actual)
      {
          float[] values = actual.ToArray<float>();
          Assert.AreEqual(expected.Length, values.Length, "Element count mismatch.");
          for (int i = 0; i < expected.Length; i++)
          {
              AssertClose(expected[i], values[i], $"Mismatch at element {i}.");
          }
      }

      /// <summary>
      /// https://www.tensorflow.org/api_docs/python/tf/keras/losses/BinaryCrossentropy
      /// </summary>
      [TestMethod]
      public void BinaryCrossentropy()
      {
          // Example 1: (batch_size = 1, number of samples = 4)
          var y_true = tf.constant(new float[] { 0, 1, 0, 0 });
          var y_pred = tf.constant(new float[] { -18.6f, 0.51f, 2.94f, -12.8f });
          var bce = tf.keras.losses.BinaryCrossentropy(from_logits: true);
          var loss = bce.Call(y_true, y_pred);
          AssertClose(0.865458f, (float)loss);

          // Example 2: the same four values reshaped into a [2, 2] batch.
          y_true = tf.constant(new float[,] { { 0, 1 }, { 0, 0 } });
          y_pred = tf.constant(new float[,] { { -18.6f, 0.51f }, { 2.94f, -12.8f } });
          bce = tf.keras.losses.BinaryCrossentropy(from_logits: true);
          loss = bce.Call(y_true, y_pred);
          AssertClose(0.865458f, (float)loss);

          // Using 'sample_weight' attribute: one weight per row of the batch.
          loss = bce.Call(y_true, y_pred, sample_weight: tf.constant(new[] { 0.8f, 0.2f }));
          AssertClose(0.2436386f, (float)loss);

          // Using 'sum' reduction type.
          bce = tf.keras.losses.BinaryCrossentropy(from_logits: true, reduction: Reduction.SUM);
          loss = bce.Call(y_true, y_pred);
          AssertClose(1.730916f, (float)loss);

          // Using 'none' reduction type: one loss value per row is returned, so every
          // element is checked rather than relying on an array-valued '==' comparison.
          bce = tf.keras.losses.BinaryCrossentropy(from_logits: true, reduction: Reduction.NONE);
          loss = bce.Call(y_true, y_pred);
          AssertAllClose(new float[] { 0.23515666f, 1.4957594f }, loss.numpy());
      }

      /// <summary>
      /// https://www.tensorflow.org/addons/api_docs/python/tfa/losses/SigmoidFocalCrossEntropy
      /// </summary>
      [TestMethod]
      public void SigmoidFocalCrossEntropy()
      {
          var y_true = np.expand_dims(np.array(new[] { 1.0f, 1.0f, 0 }));
          var y_pred = np.expand_dims(np.array(new[] { 0.97f, 0.91f, 0.03f }));
          var focal = tf.keras.losses.SigmoidFocalCrossEntropy();
          var loss = focal.Call(y_true, y_pred);
          // Compare element-wise with tolerance instead of exact equality between a
          // float[] and an NDArray.
          AssertAllClose(new[] { 6.8532745e-06f, 1.909787e-04f, 2.0559824e-05f }, loss.numpy());
      }
  }