You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MetricsApi.cs 1.5 kB

1234567891011121314151617181920212223242526272829303132333435
  1. namespace Tensorflow.Keras.Metrics
  2. {
  3. public class MetricsApi
  4. {
  5. public Tensor categorical_accuracy(Tensor y_true, Tensor y_pred)
  6. {
  7. var eql = math_ops.equal(math_ops.argmax(y_true, -1), math_ops.argmax(y_pred, -1));
  8. return math_ops.cast(eql, TF_DataType.TF_FLOAT);
  9. }
  10. /// <summary>
  11. /// Calculates how often predictions matches integer labels.
  12. /// </summary>
  13. /// <param name="y_true">Integer ground truth values.</param>
  14. /// <param name="y_pred">The prediction values.</param>
  15. /// <returns>Sparse categorical accuracy values.</returns>
  16. public Tensor sparse_categorical_accuracy(Tensor y_true, Tensor y_pred)
  17. {
  18. var y_pred_rank = y_pred.TensorShape.ndim;
  19. var y_true_rank = y_true.TensorShape.ndim;
  20. // If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
  21. if (y_true_rank != -1 && y_pred_rank != -1
  22. && y_true.shape.Length == y_pred.shape.Length)
  23. y_true = array_ops.squeeze(y_true, axis: new[] { -1 });
  24. y_pred = math_ops.argmax(y_pred, -1);
  25. // If the predicted output and actual output types don't match, force cast them
  26. // to match.
  27. if (y_pred.dtype != y_true.dtype)
  28. y_pred = math_ops.cast(y_pred, y_true.dtype);
  29. return math_ops.cast(math_ops.equal(y_true, y_pred), TF_DataType.TF_FLOAT);
  30. }
  31. }
  32. }