diff --git a/src/TensorFlowNET.Core/Keras/Losses/ILossesApi.cs b/src/TensorFlowNET.Core/Keras/Losses/ILossesApi.cs
index c4249336..4c92512d 100644
--- a/src/TensorFlowNET.Core/Keras/Losses/ILossesApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Losses/ILossesApi.cs
@@ -38,4 +38,19 @@ public interface ILossesApi
ILossFunc LogCosh(string reduction = null,
string name = null);
+
+ ///
+ /// Implements the focal loss function.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ILossFunc SigmoidFocalCrossEntropy(bool from_logits = false,
+ float alpha = 0.25f,
+ float gamma = 2.0f,
+ string reduction = "none",
+ string name = "sigmoid_focal_crossentropy");
}
diff --git a/src/TensorFlowNET.Keras/Losses/LossesApi.cs b/src/TensorFlowNET.Keras/Losses/LossesApi.cs
index 29e15e53..79f16a2e 100644
--- a/src/TensorFlowNET.Keras/Losses/LossesApi.cs
+++ b/src/TensorFlowNET.Keras/Losses/LossesApi.cs
@@ -37,5 +37,16 @@
public ILossFunc LogCosh(string reduction = null, string name = null)
=> new LogCosh(reduction: reduction, name: name);
+
+ public ILossFunc SigmoidFocalCrossEntropy(bool from_logits = false,
+ float alpha = 0.25F,
+ float gamma = 2,
+ string reduction = "none",
+ string name = "sigmoid_focal_crossentropy")
+ => new SigmoidFocalCrossEntropy(from_logits: from_logits,
+ alpha: alpha,
+ gamma: gamma,
+ reduction: reduction,
+ name: name);
}
}
diff --git a/src/TensorFlowNET.Keras/Losses/SigmoidFocalCrossEntropy.cs b/src/TensorFlowNET.Keras/Losses/SigmoidFocalCrossEntropy.cs
new file mode 100644
index 00000000..7ac3fa0b
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Losses/SigmoidFocalCrossEntropy.cs
@@ -0,0 +1,48 @@
+using static HDF.PInvoke.H5L.info_t;
+
+namespace Tensorflow.Keras.Losses;
+
+public class SigmoidFocalCrossEntropy : LossFunctionWrapper, ILossFunc
+{
+ float _alpha;
+ float _gamma;
+
+ public SigmoidFocalCrossEntropy(bool from_logits = false,
+ float alpha = 0.25f,
+ float gamma = 2.0f,
+ string reduction = "none",
+ string name = "sigmoid_focal_crossentropy") :
+ base(reduction: reduction,
+ name: name,
+ from_logits: from_logits)
+ {
+ _alpha = alpha;
+ _gamma = gamma;
+ }
+
+
+ public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
+ {
+ y_true = tf.cast(y_true, dtype: y_pred.dtype);
+ var ce = keras.backend.binary_crossentropy(y_true, y_pred, from_logits: from_logits);
+ var pred_prob = from_logits ? tf.sigmoid(y_pred) : y_pred;
+
+ var p_t = (y_true * pred_prob) + ((1f - y_true) * (1f - pred_prob));
+ Tensor alpha_factor = constant_op.constant(1.0f);
+ Tensor modulating_factor = constant_op.constant(1.0f);
+
+ if(_alpha > 0)
+ {
+ var alpha = tf.cast(constant_op.constant(_alpha), dtype: y_true.dtype);
+ alpha_factor = y_true * alpha + (1f - y_true) * (1f - alpha);
+ }
+
+ if (_gamma > 0)
+ {
+ var gamma = tf.cast(constant_op.constant(_gamma), dtype: y_true.dtype);
+ modulating_factor = tf.pow(1f - p_t, gamma);
+ }
+
+ return tf.reduce_sum(alpha_factor * modulating_factor * ce, axis = -1);
+ }
+}
diff --git a/test/TensorFlowNET.Keras.UnitTest/Losses/LossesTest.cs b/test/TensorFlowNET.Keras.UnitTest/Losses/LossesTest.cs
index b19f0203..98fa1de1 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Losses/LossesTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Losses/LossesTest.cs
@@ -5,6 +5,7 @@ using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Tensorflow;
+using Tensorflow.NumPy;
using TensorFlowNET.Keras.UnitTest;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
@@ -48,4 +49,17 @@ public class LossesTest : EagerModeTestBase
loss = bce.Call(y_true, y_pred);
Assert.AreEqual(new float[] { 0.23515666f, 1.4957594f}, loss.numpy());
}
+
+ ///
+ /// https://www.tensorflow.org/addons/api_docs/python/tfa/losses/SigmoidFocalCrossEntropy
+ ///
+ [TestMethod]
+ public void SigmoidFocalCrossEntropy()
+ {
+ var y_true = np.expand_dims(np.array(new[] { 1.0f, 1.0f, 0 }));
+ var y_pred = np.expand_dims(np.array(new[] { 0.97f, 0.91f, 0.03f }));
+ var bce = tf.keras.losses.SigmoidFocalCrossEntropy();
+ var loss = bce.Call(y_true, y_pred);
+ Assert.AreEqual(new[] { 6.8532745e-06f, 1.909787e-04f, 2.0559824e-05f }, loss.numpy());
+ }
}