Browse Source

Support SigmoidFocalCrossEntropy, better for imbalanced multi-class task.

tags/v0.100.4-load-saved-model
Haiping Chen 2 years ago
parent
commit
e5dc65a7d9
4 changed files with 88 additions and 0 deletions
  1. +15
    -0
      src/TensorFlowNET.Core/Keras/Losses/ILossesApi.cs
  2. +11
    -0
      src/TensorFlowNET.Keras/Losses/LossesApi.cs
  3. +48
    -0
      src/TensorFlowNET.Keras/Losses/SigmoidFocalCrossEntropy.cs
  4. +14
    -0
      test/TensorFlowNET.Keras.UnitTest/Losses/LossesTest.cs

+ 15
- 0
src/TensorFlowNET.Core/Keras/Losses/ILossesApi.cs View File

@@ -38,4 +38,19 @@ public interface ILossesApi

ILossFunc LogCosh(string reduction = null,
string name = null);

/// <summary>
/// Implements the focal loss function.
/// </summary>
/// <param name="from_logits"></param>
/// <param name="alpha"></param>
/// <param name="gamma"></param>
/// <param name="reduction"></param>
/// <param name="name"></param>
/// <returns></returns>
ILossFunc SigmoidFocalCrossEntropy(bool from_logits = false,
float alpha = 0.25f,
float gamma = 2.0f,
string reduction = "none",
string name = "sigmoid_focal_crossentropy");
}

+ 11
- 0
src/TensorFlowNET.Keras/Losses/LossesApi.cs View File

@@ -37,5 +37,16 @@

public ILossFunc LogCosh(string reduction = null, string name = null)
=> new LogCosh(reduction: reduction, name: name);

public ILossFunc SigmoidFocalCrossEntropy(bool from_logits = false,
float alpha = 0.25F,
float gamma = 2,
string reduction = "none",
string name = "sigmoid_focal_crossentropy")
=> new SigmoidFocalCrossEntropy(from_logits: from_logits,
alpha: alpha,
gamma: gamma,
reduction: reduction,
name: name);
}
}

+ 48
- 0
src/TensorFlowNET.Keras/Losses/SigmoidFocalCrossEntropy.cs View File

@@ -0,0 +1,48 @@
using static HDF.PInvoke.H5L.info_t;

namespace Tensorflow.Keras.Losses;

public class SigmoidFocalCrossEntropy : LossFunctionWrapper, ILossFunc
{
float _alpha;
float _gamma;

public SigmoidFocalCrossEntropy(bool from_logits = false,
float alpha = 0.25f,
float gamma = 2.0f,
string reduction = "none",
string name = "sigmoid_focal_crossentropy") :
base(reduction: reduction,
name: name,
from_logits: from_logits)
{
_alpha = alpha;
_gamma = gamma;
}


public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
{
y_true = tf.cast(y_true, dtype: y_pred.dtype);
var ce = keras.backend.binary_crossentropy(y_true, y_pred, from_logits: from_logits);
var pred_prob = from_logits ? tf.sigmoid(y_pred) : y_pred;

var p_t = (y_true * pred_prob) + ((1f - y_true) * (1f - pred_prob));
Tensor alpha_factor = constant_op.constant(1.0f);
Tensor modulating_factor = constant_op.constant(1.0f);

if(_alpha > 0)
{
var alpha = tf.cast(constant_op.constant(_alpha), dtype: y_true.dtype);
alpha_factor = y_true * alpha + (1f - y_true) * (1f - alpha);
}

if (_gamma > 0)
{
var gamma = tf.cast(constant_op.constant(_gamma), dtype: y_true.dtype);
modulating_factor = tf.pow(1f - p_t, gamma);
}

return tf.reduce_sum(alpha_factor * modulating_factor * ce, axis = -1);
}
}

+ 14
- 0
test/TensorFlowNET.Keras.UnitTest/Losses/LossesTest.cs View File

@@ -5,6 +5,7 @@ using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Tensorflow;
using Tensorflow.NumPy;
using TensorFlowNET.Keras.UnitTest;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
@@ -48,4 +49,17 @@ public class LossesTest : EagerModeTestBase
loss = bce.Call(y_true, y_pred);
Assert.AreEqual(new float[] { 0.23515666f, 1.4957594f}, loss.numpy());
}

/// <summary>
/// https://www.tensorflow.org/addons/api_docs/python/tfa/losses/SigmoidFocalCrossEntropy
/// </summary>
[TestMethod]
public void SigmoidFocalCrossEntropy()
{
var y_true = np.expand_dims(np.array(new[] { 1.0f, 1.0f, 0 }));
var y_pred = np.expand_dims(np.array(new[] { 0.97f, 0.91f, 0.03f }));
var bce = tf.keras.losses.SigmoidFocalCrossEntropy();
var loss = bce.Call(y_true, y_pred);
Assert.AreEqual(new[] { 6.8532745e-06f, 1.909787e-04f, 2.0559824e-05f }, loss.numpy());
}
}

Loading…
Cancel
Save