You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SigmoidFocalCrossEntropy.cs 1.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. using static HDF.PInvoke.H5L.info_t;
  2. namespace Tensorflow.Keras.Losses;
  /// <summary>
  /// Sigmoid focal cross-entropy loss.
  /// Element-wise binary cross-entropy is scaled by an optional balancing factor
  /// (<c>alpha</c> for positive targets, <c>1 - alpha</c> for negative ones) and an
  /// optional modulating factor <c>(1 - p_t)^gamma</c>. The result is then summed along
  /// <c>axis</c>.
  /// </summary>
  public class SigmoidFocalCrossEntropy : LossFunctionWrapper
  {
      // Balancing weight for positive targets; the factor is skipped when <= 0.
      readonly float _alpha;
      // Exponent of the (1 - p_t) modulating factor; the factor is skipped when <= 0.
      readonly float _gamma;

      /// <summary>Creates the loss.</summary>
      /// <param name="from_logits">Forwarded to the base wrapper. NOTE(review): presumably passed
      /// back into <see cref="Apply"/> by the base class; confirm.</param>
      /// <param name="alpha">Balancing weight; values &lt;= 0 disable the balancing factor.</param>
      /// <param name="gamma">Focusing exponent; values &lt;= 0 disable the modulating factor.</param>
      /// <param name="reduction">Reduction mode string forwarded to the base wrapper.</param>
      /// <param name="name">Loss name forwarded to the base wrapper.</param>
      public SigmoidFocalCrossEntropy(bool from_logits = false,
          float alpha = 0.25f,
          float gamma = 2.0f,
          string reduction = "none",
          string name = "sigmoid_focal_crossentropy") :
          base(reduction: reduction,
              name: name,
              from_logits: from_logits)
      {
          _alpha = alpha;
          _gamma = gamma;
      }

      /// <summary>Computes the focal loss between <paramref name="y_true"/> and <paramref name="y_pred"/>.</summary>
      /// <param name="y_true">Target values; cast to <paramref name="y_pred"/>'s dtype.</param>
      /// <param name="y_pred">Predictions: probabilities, or logits when <paramref name="from_logits"/> is true.</param>
      /// <param name="from_logits">When true, predictions are passed through a sigmoid to obtain probabilities.</param>
      /// <param name="axis">Axis along which the weighted cross-entropy is summed (default -1).</param>
      /// <returns>The weighted cross-entropy summed along <paramref name="axis"/>.</returns>
      public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
      {
          // Work in the prediction dtype so every arithmetic op below sees matching dtypes.
          y_true = tf.cast(y_true, dtype: y_pred.dtype);

          var ce = keras.backend.binary_crossentropy(y_true, y_pred, from_logits: from_logits);

          // For 0/1 labels, p_t is the predicted probability of the true label.
          var pred_prob = from_logits ? tf.sigmoid(y_pred) : y_pred;
          var p_t = (y_true * pred_prob) + ((1f - y_true) * (1f - pred_prob));

          // Neutral factors are created in y_true's dtype rather than as a float32 constant,
          // so they still combine with ce when the inputs are not float32.
          var one = tf.cast(constant_op.constant(1.0f), dtype: y_true.dtype);
          Tensor alpha_factor = one;
          Tensor modulating_factor = one;

          if (_alpha > 0)
          {
              var alpha = tf.cast(constant_op.constant(_alpha), dtype: y_true.dtype);
              alpha_factor = y_true * alpha + (1f - y_true) * (1f - alpha);
          }

          if (_gamma > 0)
          {
              var gamma = tf.cast(constant_op.constant(_gamma), dtype: y_true.dtype);
              modulating_factor = tf.pow(1f - p_t, gamma);
          }

          // Named argument: reduce over the caller-supplied axis (a plain `axis = -1`
          // here would be an assignment that discards the caller's value).
          return tf.reduce_sum(alpha_factor * modulating_factor * ce, axis: axis);
      }
  }