You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CosineSimilarity.cs 767 B

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122
  namespace Tensorflow.Keras.Losses;

  /// <summary>
  /// Loss that returns the negative cosine similarity between <c>y_true</c> and
  /// <c>y_pred</c>, measured along an axis chosen at construction time.
  /// </summary>
  /// <remarks>
  /// Both inputs are L2-normalized along the configured axis. Their element-wise product
  /// is then summed along that same axis, and the result is negated. As a result, more
  /// closely aligned inputs produce more negative loss values.
  /// </remarks>
  public class CosineSimilarity : LossFunctionWrapper
  {
      /// <summary>Axis along which both normalization and the dot-product sum are taken.</summary>
      protected int axis = -1;

      /// <summary>Creates a cosine-similarity loss.</summary>
      /// <param name="reduction">
      /// Reduction mode, forwarded unchanged to <see cref="LossFunctionWrapper"/>.
      /// NOTE(review): <c>null</c> presumably selects the base-class default — confirm.
      /// </param>
      /// <param name="axis">Axis along which similarity is computed (default <c>-1</c>).</param>
      /// <param name="name">Loss name; <c>"cosine_similarity"</c> is used when <c>null</c>.</param>
      public CosineSimilarity(
          string reduction = null,
          int axis = -1,
          string name = null)
          : base(reduction: reduction, name: name ?? "cosine_similarity")
      {
          this.axis = axis;
      }

      /// <summary>Computes the negative cosine similarity of the two inputs.</summary>
      /// <param name="y_true">Target tensor.</param>
      /// <param name="y_pred">Prediction tensor.</param>
      /// <param name="from_logits">Not used by this loss.</param>
      /// <param name="axis">
      /// Ignored. The axis supplied to the constructor is always used instead.
      /// </param>
      /// <returns>
      /// <c>-reduce_sum(l2_normalize(y_true) * l2_normalize(y_pred))</c>, with every
      /// operation taken along the constructor-configured axis.
      /// </returns>
      public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
      {
          // The field (set in the constructor) is used rather than the parameter. The parameter
          // exists only because this method overrides a base-class signature that declares it.
          Tensor y_true_normalize = nn_impl.l2_normalize(y_true, axis: this.axis);
          Tensor y_pred_normalize = nn_impl.l2_normalize(y_pred, axis: this.axis);

          // The dot product of L2-normalized vectors is their cosine similarity.
          // It is negated so that minimizing the loss maximizes similarity.
          return -math_ops.reduce_sum(y_true_normalize * y_pred_normalize, axis: constant_op.constant(this.axis));
      }
  }