@@ -38,4 +38,19 @@ public interface ILossesApi | |||||
ILossFunc LogCosh(string reduction = null, | ILossFunc LogCosh(string reduction = null, | ||||
string name = null); | string name = null); | ||||
/// <summary> | |||||
/// Implements the focal loss function. | |||||
/// </summary> | |||||
/// <param name="from_logits"></param> | |||||
/// <param name="alpha"></param> | |||||
/// <param name="gamma"></param> | |||||
/// <param name="reduction"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
ILossFunc SigmoidFocalCrossEntropy(bool from_logits = false, | |||||
float alpha = 0.25f, | |||||
float gamma = 2.0f, | |||||
string reduction = "none", | |||||
string name = "sigmoid_focal_crossentropy"); | |||||
} | } |
@@ -37,5 +37,16 @@ | |||||
public ILossFunc LogCosh(string reduction = null, string name = null) | public ILossFunc LogCosh(string reduction = null, string name = null) | ||||
=> new LogCosh(reduction: reduction, name: name); | => new LogCosh(reduction: reduction, name: name); | ||||
public ILossFunc SigmoidFocalCrossEntropy(bool from_logits = false, | |||||
float alpha = 0.25F, | |||||
float gamma = 2, | |||||
string reduction = "none", | |||||
string name = "sigmoid_focal_crossentropy") | |||||
=> new SigmoidFocalCrossEntropy(from_logits: from_logits, | |||||
alpha: alpha, | |||||
gamma: gamma, | |||||
reduction: reduction, | |||||
name: name); | |||||
} | } | ||||
} | } |
@@ -0,0 +1,48 @@ | |||||
using static HDF.PInvoke.H5L.info_t; | |||||
namespace Tensorflow.Keras.Losses; | |||||
public class SigmoidFocalCrossEntropy : LossFunctionWrapper, ILossFunc | |||||
{ | |||||
float _alpha; | |||||
float _gamma; | |||||
public SigmoidFocalCrossEntropy(bool from_logits = false, | |||||
float alpha = 0.25f, | |||||
float gamma = 2.0f, | |||||
string reduction = "none", | |||||
string name = "sigmoid_focal_crossentropy") : | |||||
base(reduction: reduction, | |||||
name: name, | |||||
from_logits: from_logits) | |||||
{ | |||||
_alpha = alpha; | |||||
_gamma = gamma; | |||||
} | |||||
public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1) | |||||
{ | |||||
y_true = tf.cast(y_true, dtype: y_pred.dtype); | |||||
var ce = keras.backend.binary_crossentropy(y_true, y_pred, from_logits: from_logits); | |||||
var pred_prob = from_logits ? tf.sigmoid(y_pred) : y_pred; | |||||
var p_t = (y_true * pred_prob) + ((1f - y_true) * (1f - pred_prob)); | |||||
Tensor alpha_factor = constant_op.constant(1.0f); | |||||
Tensor modulating_factor = constant_op.constant(1.0f); | |||||
if(_alpha > 0) | |||||
{ | |||||
var alpha = tf.cast(constant_op.constant(_alpha), dtype: y_true.dtype); | |||||
alpha_factor = y_true * alpha + (1f - y_true) * (1f - alpha); | |||||
} | |||||
if (_gamma > 0) | |||||
{ | |||||
var gamma = tf.cast(constant_op.constant(_gamma), dtype: y_true.dtype); | |||||
modulating_factor = tf.pow(1f - p_t, gamma); | |||||
} | |||||
return tf.reduce_sum(alpha_factor * modulating_factor * ce, axis = -1); | |||||
} | |||||
} |
@@ -5,6 +5,7 @@ using System.Linq; | |||||
using System.Text; | using System.Text; | ||||
using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.NumPy; | |||||
using TensorFlowNET.Keras.UnitTest; | using TensorFlowNET.Keras.UnitTest; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
@@ -48,4 +49,17 @@ public class LossesTest : EagerModeTestBase | |||||
loss = bce.Call(y_true, y_pred); | loss = bce.Call(y_true, y_pred); | ||||
Assert.AreEqual(new float[] { 0.23515666f, 1.4957594f}, loss.numpy()); | Assert.AreEqual(new float[] { 0.23515666f, 1.4957594f}, loss.numpy()); | ||||
} | } | ||||
/// <summary> | |||||
/// https://www.tensorflow.org/addons/api_docs/python/tfa/losses/SigmoidFocalCrossEntropy | |||||
/// </summary> | |||||
[TestMethod] | |||||
public void SigmoidFocalCrossEntropy() | |||||
{ | |||||
var y_true = np.expand_dims(np.array(new[] { 1.0f, 1.0f, 0 })); | |||||
var y_pred = np.expand_dims(np.array(new[] { 0.97f, 0.91f, 0.03f })); | |||||
var bce = tf.keras.losses.SigmoidFocalCrossEntropy(); | |||||
var loss = bce.Call(y_true, y_pred); | |||||
Assert.AreEqual(new[] { 6.8532745e-06f, 1.909787e-04f, 2.0559824e-05f }, loss.numpy()); | |||||
} | |||||
} | } |