using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.UnitTest.Losses;

/// <summary>
/// Unit tests for the loss functions exposed through <c>tf.keras.losses</c>.
/// </summary>
[TestClass]
public class LossesTest : EagerModeTestBase
{
    /// <summary>
    /// Absolute tolerance used when comparing scalar float losses, so the tests do not
    /// depend on bit-exact floating-point results.
    /// </summary>
    private const float Tolerance = 1e-5f;

    /// <summary>
    /// Checks BinaryCrossentropy (computed from logits) with the default reduction,
    /// with sample weights, and with the SUM and NONE reductions.
    /// https://www.tensorflow.org/api_docs/python/tf/keras/losses/BinaryCrossentropy
    /// </summary>
    [TestMethod]
    public void BinaryCrossentropy()
    {
        // Example 1: (batch_size = 1, number of samples = 4)
        var y_true = tf.constant(new float[] { 0, 1, 0, 0 });
        var y_pred = tf.constant(new float[] { -18.6f, 0.51f, 2.94f, -12.8f });
        var bce = tf.keras.losses.BinaryCrossentropy(from_logits: true);
        var loss = bce.Call(y_true, y_pred);
        // MSTest expects (expected, actual, delta); the delta avoids exact float equality.
        Assert.AreEqual(0.865458f, (float)loss, Tolerance);

        // Example 2: (batch_size = 2, number of samples = 4)
        y_true = tf.constant(new float[,] { { 0, 1 }, { 0, 0 } });
        y_pred = tf.constant(new float[,] { { -18.6f, 0.51f }, { 2.94f, -12.8f } });
        bce = tf.keras.losses.BinaryCrossentropy(from_logits: true);
        loss = bce.Call(y_true, y_pred);
        Assert.AreEqual(0.865458f, (float)loss, Tolerance);

        // Using 'sample_weight' attribute
        loss = bce.Call(y_true, y_pred, sample_weight: tf.constant(new[] { 0.8f, 0.2f }));
        Assert.AreEqual(0.2436386f, (float)loss, Tolerance);

        // Using 'sum' reduction type.
        bce = tf.keras.losses.BinaryCrossentropy(from_logits: true, reduction: Reduction.SUM);
        loss = bce.Call(y_true, y_pred);
        Assert.AreEqual(1.730916f, (float)loss, Tolerance);

        // Using 'none' reduction type.
        bce = tf.keras.losses.BinaryCrossentropy(from_logits: true, reduction: Reduction.NONE);
        loss = bce.Call(y_true, y_pred);
        // NOTE(review): relies on NDArray's == operator and its conversion to bool;
        // confirm that this checks every element rather than only the first.
        Assert.IsTrue(new NDArray(new float[] { 0.23515666f, 1.4957594f }) == loss.numpy());
    }

    /// <summary>
    /// Checks SigmoidFocalCrossEntropy with its default arguments against the expected
    /// per-sample losses.
    /// https://www.tensorflow.org/addons/api_docs/python/tfa/losses/SigmoidFocalCrossEntropy
    /// </summary>
    [TestMethod]
    public void SigmoidFocalCrossEntropy()
    {
        var y_true = np.expand_dims(np.array(new[] { 1.0f, 1.0f, 0 }));
        var y_pred = np.expand_dims(np.array(new[] { 0.97f, 0.91f, 0.03f }));
        var focalLoss = tf.keras.losses.SigmoidFocalCrossEntropy();
        var loss = focalLoss.Call(y_true, y_pred);
        // NOTE(review): compares a float[] against an NDArray; presumably resolved through
        // NDArray's own equality semantics — confirm it compares element values.
        Assert.AreEqual(new[] { 6.8532745e-06f, 1.909787e-04f, 2.0559824e-05f }, loss.numpy());
    }
}