diff --git a/src/TensorFlowNET.Core/Keras/Losses/ILossesApi.cs b/src/TensorFlowNET.Core/Keras/Losses/ILossesApi.cs index c4249336..4c92512d 100644 --- a/src/TensorFlowNET.Core/Keras/Losses/ILossesApi.cs +++ b/src/TensorFlowNET.Core/Keras/Losses/ILossesApi.cs @@ -38,4 +38,19 @@ public interface ILossesApi ILossFunc LogCosh(string reduction = null, string name = null); + + /// + /// Implements the focal loss function. + /// + /// + /// + /// + /// + /// + /// + ILossFunc SigmoidFocalCrossEntropy(bool from_logits = false, + float alpha = 0.25f, + float gamma = 2.0f, + string reduction = "none", + string name = "sigmoid_focal_crossentropy"); } diff --git a/src/TensorFlowNET.Keras/Losses/LossesApi.cs b/src/TensorFlowNET.Keras/Losses/LossesApi.cs index 29e15e53..79f16a2e 100644 --- a/src/TensorFlowNET.Keras/Losses/LossesApi.cs +++ b/src/TensorFlowNET.Keras/Losses/LossesApi.cs @@ -37,5 +37,16 @@ public ILossFunc LogCosh(string reduction = null, string name = null) => new LogCosh(reduction: reduction, name: name); + + public ILossFunc SigmoidFocalCrossEntropy(bool from_logits = false, + float alpha = 0.25F, + float gamma = 2, + string reduction = "none", + string name = "sigmoid_focal_crossentropy") + => new SigmoidFocalCrossEntropy(from_logits: from_logits, + alpha: alpha, + gamma: gamma, + reduction: reduction, + name: name); } } diff --git a/src/TensorFlowNET.Keras/Losses/SigmoidFocalCrossEntropy.cs b/src/TensorFlowNET.Keras/Losses/SigmoidFocalCrossEntropy.cs new file mode 100644 index 00000000..7ac3fa0b --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/SigmoidFocalCrossEntropy.cs @@ -0,0 +1,48 @@ +using static HDF.PInvoke.H5L.info_t; + +namespace Tensorflow.Keras.Losses; + +public class SigmoidFocalCrossEntropy : LossFunctionWrapper, ILossFunc +{ + float _alpha; + float _gamma; + + public SigmoidFocalCrossEntropy(bool from_logits = false, + float alpha = 0.25f, + float gamma = 2.0f, + string reduction = "none", + string name = "sigmoid_focal_crossentropy") : + base(reduction: reduction, + name: name, + from_logits: from_logits) + { + _alpha = alpha; + _gamma = gamma; + } + + + public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1) + { + y_true = tf.cast(y_true, dtype: y_pred.dtype); + var ce = keras.backend.binary_crossentropy(y_true, y_pred, from_logits: from_logits); + var pred_prob = from_logits ? tf.sigmoid(y_pred) : y_pred; + + var p_t = (y_true * pred_prob) + ((1f - y_true) * (1f - pred_prob)); + Tensor alpha_factor = constant_op.constant(1.0f); + Tensor modulating_factor = constant_op.constant(1.0f); + + if(_alpha > 0) + { + var alpha = tf.cast(constant_op.constant(_alpha), dtype: y_true.dtype); + alpha_factor = y_true * alpha + (1f - y_true) * (1f - alpha); + } + + if (_gamma > 0) + { + var gamma = tf.cast(constant_op.constant(_gamma), dtype: y_true.dtype); + modulating_factor = tf.pow(1f - p_t, gamma); + } + + return tf.reduce_sum(alpha_factor * modulating_factor * ce, axis = -1); + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Losses/LossesTest.cs b/test/TensorFlowNET.Keras.UnitTest/Losses/LossesTest.cs index b19f0203..98fa1de1 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Losses/LossesTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Losses/LossesTest.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Text; using System.Threading.Tasks; using Tensorflow; +using Tensorflow.NumPy; using TensorFlowNET.Keras.UnitTest; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -48,4 +49,17 @@ public class LossesTest : EagerModeTestBase loss = bce.Call(y_true, y_pred); Assert.AreEqual(new float[] { 0.23515666f, 1.4957594f}, loss.numpy()); } + + /// + /// https://www.tensorflow.org/addons/api_docs/python/tfa/losses/SigmoidFocalCrossEntropy + /// + [TestMethod] + public void SigmoidFocalCrossEntropy() + { + var y_true = np.expand_dims(np.array(new[] { 1.0f, 1.0f, 0 })); + var y_pred = np.expand_dims(np.array(new[] { 0.97f, 0.91f, 0.03f })); + var bce = tf.keras.losses.SigmoidFocalCrossEntropy(); + var loss = bce.Call(y_true, y_pred); + Assert.AreEqual(new[] { 6.8532745e-06f, 1.909787e-04f, 2.0559824e-05f }, loss.numpy()); + } }