You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MetricsApi.cs 1.2 kB

1234567891011121314151617181920212223242526272829
  namespace Tensorflow.Keras.Metrics
  {
      public class MetricsApi
      {
          /// <summary>
          /// Compares, element by element, the arg-max of the predictions with the
          /// integer ground-truth labels.
          /// </summary>
          /// <remarks>
          /// When both ranks are known and <paramref name="y_true"/> has the same rank as
          /// <paramref name="y_pred"/> (for example labels shaped (num_samples, 1)), the
          /// trailing axis of the labels is squeezed away before the comparison.
          /// If the arg-max dtype differs from the label dtype, the arg-max is cast to
          /// the label dtype.
          /// </remarks>
          /// <param name="y_true">Integer ground truth values.</param>
          /// <param name="y_pred">The prediction values; the arg-max is taken over the last axis.</param>
          /// <returns>The result of <c>math_ops.equal</c> on labels and predicted labels, cast to TF_FLOAT.</returns>
          public Tensor sparse_categorical_accuracy(Tensor y_true, Tensor y_pred)
          {
              var predRank = y_pred.TensorShape.ndim;
              var trueRank = y_true.TensorShape.ndim;
              bool ranksKnown = trueRank != -1 && predRank != -1;

              // Labels shaped (num_samples, 1) are squeezed to (num_samples,) so that they
              // line up with the arg-max of the predictions computed below.
              if (ranksKnown && y_true.shape.Length == y_pred.shape.Length)
              {
                  y_true = array_ops.squeeze(y_true, axis: new[] { -1 });
              }

              Tensor predictedLabels = math_ops.argmax(y_pred, -1);

              // Force the predicted labels into the label dtype so the equality check
              // compares like with like.
              if (predictedLabels.dtype != y_true.dtype)
              {
                  predictedLabels = math_ops.cast(predictedLabels, y_true.dtype);
              }

              var matches = math_ops.equal(y_true, predictedLabels);
              return math_ops.cast(matches, TF_DataType.TF_FLOAT);
          }
      }
  }