@@ -0,0 +1,17 @@ | |||||
namespace Tensorflow.Keras.Metrics; | |||||
public interface IMetricFunc | |||||
{ | |||||
/// <summary> | |||||
/// Accumulates metric statistics. | |||||
/// </summary> | |||||
/// <param name="y_true"></param> | |||||
/// <param name="y_pred"></param> | |||||
/// <param name="sample_weight"></param> | |||||
/// <returns></returns> | |||||
Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null); | |||||
Tensor result(); | |||||
void reset_states(); | |||||
} |
@@ -26,4 +26,13 @@ public interface IMetricsApi | |||||
/// <param name="k"></param> | /// <param name="k"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5); | Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5); | ||||
/// <summary> | |||||
/// Computes how often targets are in the top K predictions. | |||||
/// </summary> | |||||
/// <param name="y_true"></param> | |||||
/// <param name="y_pred"></param> | |||||
/// <param name="k"></param> | |||||
/// <returns></returns> | |||||
IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT); | |||||
} | } |
@@ -1,4 +1,5 @@ | |||||
using System; | using System; | ||||
using Tensorflow.Keras.Utils; | |||||
namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
{ | { | ||||
@@ -17,6 +18,8 @@ namespace Tensorflow.Keras.Metrics | |||||
y_true = math_ops.cast(y_true, _dtype); | y_true = math_ops.cast(y_true, _dtype); | ||||
y_pred = math_ops.cast(y_pred, _dtype); | y_pred = math_ops.cast(y_pred, _dtype); | ||||
(y_pred, y_true) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true: y_true); | |||||
var matches = _fn(y_true, y_pred); | var matches = _fn(y_true, y_pred); | ||||
return update_state(matches, sample_weight: sample_weight); | return update_state(matches, sample_weight: sample_weight); | ||||
} | } | ||||
@@ -9,7 +9,7 @@ namespace Tensorflow.Keras.Metrics | |||||
/// <summary> | /// <summary> | ||||
/// Encapsulates metric logic and state. | /// Encapsulates metric logic and state. | ||||
/// </summary> | /// </summary> | ||||
public class Metric : Layer | |||||
public class Metric : Layer, IMetricFunc | |||||
{ | { | ||||
protected IVariableV1 total; | protected IVariableV1 total; | ||||
protected IVariableV1 count; | protected IVariableV1 count; | ||||
@@ -1,6 +1,4 @@ | |||||
using static Tensorflow.KerasApi; | |||||
namespace Tensorflow.Keras.Metrics | |||||
namespace Tensorflow.Keras.Metrics | |||||
{ | { | ||||
public class MetricsApi : IMetricsApi | public class MetricsApi : IMetricsApi | ||||
{ | { | ||||
@@ -60,5 +58,8 @@ namespace Tensorflow.Keras.Metrics | |||||
tf.math.argmax(y_true, axis: -1), y_pred, k | tf.math.argmax(y_true, axis: -1), y_pred, k | ||||
); | ); | ||||
} | } | ||||
public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
=> new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype); | |||||
} | } | ||||
} | } |
@@ -0,0 +1,12 @@ | |||||
namespace Tensorflow.Keras.Metrics; | |||||
public class TopKCategoricalAccuracy : MeanMetricWrapper | |||||
{ | |||||
public TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
: base((yt, yp) => metrics_utils.sparse_top_k_categorical_matches( | |||||
tf.math.argmax(yt, axis: -1), yp, k), | |||||
name: name, | |||||
dtype: dtype) | |||||
{ | |||||
} | |||||
} |
@@ -15,6 +15,7 @@ | |||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System; | using System; | ||||
using System.Xml.Linq; | |||||
using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -37,15 +38,57 @@ namespace Tensorflow.Keras.Utils | |||||
}); | }); | ||||
} | } | ||||
public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor sample_weight) | |||||
public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor y_true = null, Tensor sample_weight = null) | |||||
{ | { | ||||
var y_pred_shape = y_pred.shape; | |||||
var y_pred_rank = y_pred_shape.ndim; | |||||
if (y_true != null) | |||||
{ | |||||
var y_true_shape = y_true.shape; | |||||
var y_true_rank = y_true_shape.ndim; | |||||
if (y_true_rank > -1 && y_pred_rank > -1) | |||||
{ | |||||
if (y_pred_rank - y_true_rank != 1 || y_pred_shape[-1] == 1) | |||||
{ | |||||
(y_true, y_pred) = remove_squeezable_dimensions(y_true, y_pred); | |||||
} | |||||
} | |||||
} | |||||
if (sample_weight == null) | |||||
{ | |||||
return (y_pred, y_true); | |||||
} | |||||
var weights_shape = sample_weight.shape; | var weights_shape = sample_weight.shape; | ||||
var weights_rank = weights_shape.ndim; | var weights_rank = weights_shape.ndim; | ||||
if (weights_rank == 0) | if (weights_rank == 0) | ||||
return (y_pred, sample_weight); | return (y_pred, sample_weight); | ||||
if (y_pred_rank > -1 && weights_rank > -1) | |||||
{ | |||||
if (weights_rank - y_pred_rank == 1) | |||||
{ | |||||
sample_weight = tf.squeeze(sample_weight, -1); | |||||
} | |||||
else if (y_pred_rank - weights_rank == 1) | |||||
{ | |||||
sample_weight = tf.expand_dims(sample_weight, -1); | |||||
} | |||||
else | |||||
{ | |||||
return (y_pred, sample_weight); | |||||
} | |||||
} | |||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
} | } | ||||
public static (Tensor, Tensor) remove_squeezable_dimensions(Tensor labels, Tensor predictions, int expected_rank_diff = 0, string name = null) | |||||
{ | |||||
return (labels, predictions); | |||||
} | |||||
public static Tensor reduce_weighted_loss(Tensor weighted_losses, string reduction) | public static Tensor reduce_weighted_loss(Tensor weighted_losses, string reduction) | ||||
{ | { | ||||
if (reduction == ReductionV2.NONE) | if (reduction == ReductionV2.NONE) | ||||
@@ -14,6 +14,26 @@ namespace TensorFlowNET.Keras.UnitTest; | |||||
[TestClass] | [TestClass] | ||||
public class MetricsTest : EagerModeTestBase | public class MetricsTest : EagerModeTestBase | ||||
{ | { | ||||
/// <summary> | |||||
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy | |||||
/// </summary> | |||||
[TestMethod] | |||||
public void TopKCategoricalAccuracy() | |||||
{ | |||||
var y_true = np.array(new[,] { { 0, 0, 1 }, { 0, 1, 0 } }); | |||||
var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } }); | |||||
var m = tf.keras.metrics.TopKCategoricalAccuracy(k: 1); | |||||
m.update_state(y_true, y_pred); | |||||
var r = m.result().numpy(); | |||||
Assert.AreEqual(r, 0.5f); | |||||
m.reset_states(); | |||||
var weights = np.array(new[] { 0.7f, 0.3f }); | |||||
m.update_state(y_true, y_pred, sample_weight: weights); | |||||
r = m.result().numpy(); | |||||
Assert.AreEqual(r, 0.3f); | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/top_k_categorical_accuracy | /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/top_k_categorical_accuracy | ||||
/// </summary> | /// </summary> | ||||