diff --git a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
index 5d08cc78..dbe4ac3f 100644
--- a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
@@ -22,6 +22,20 @@ public interface IMetricsApi
/// Sparse categorical accuracy values.
Tensor sparse_categorical_accuracy(Tensor y_true, Tensor y_pred);
+ ///
+ /// Computes the sparse categorical crossentropy loss.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ Tensor sparse_categorical_crossentropy(Tensor y_true, Tensor y_pred,
+ bool from_logits = false,
+ int? ignore_class = null,
+ Axis? axis = null);
+
///
/// Computes how often targets are in the top `K` predictions.
///
@@ -56,6 +70,16 @@ public interface IMetricsApi
float label_smoothing = 0f,
Axis? axis = null);
+ ///
+ /// Computes the crossentropy metric between the labels and predictions.
+ ///
+ ///
+ IMetricFunc SparseCategoricalCrossentropy(string name = "sparse_categorical_crossentropy",
+ TF_DataType dtype = TF_DataType.TF_FLOAT,
+ bool from_logits = false,
+ int? ignore_class = null,
+ Axis? axis = null);
+
///
/// Computes the crossentropy metric between the labels and predictions.
///
@@ -63,6 +87,13 @@ public interface IMetricsApi
IMetricFunc CategoricalAccuracy(string name = "categorical_accuracy",
TF_DataType dtype = TF_DataType.TF_FLOAT);
+ ///
+ /// Calculates how often predictions match integer labels.
+ ///
+ ///
+ IMetricFunc SparseCategoricalAccuracy(string name = "sparse_categorical_accuracy",
+ TF_DataType dtype = TF_DataType.TF_FLOAT);
+
///
/// Computes the cosine similarity between the labels and predictions.
///
@@ -114,6 +145,15 @@ public interface IMetricsApi
string name = "top_k_categorical_accuracy",
TF_DataType dtype = TF_DataType.TF_FLOAT);
+ ///
+ /// Computes how often integer targets are in the top K predictions.
+ ///
+ ///
+ ///
+ IMetricFunc SparseTopKCategoricalAccuracy(int k = 5,
+ string name = "sparse_top_k_categorical_accuracy",
+ TF_DataType dtype = TF_DataType.TF_FLOAT);
+
///
/// Computes the precision of the predictions with respect to the labels.
///
diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs
index 0c9da015..c49fc140 100644
--- a/src/TensorFlowNET.Keras/BackendImpl.cs
+++ b/src/TensorFlowNET.Keras/BackendImpl.cs
@@ -276,6 +276,64 @@ namespace Tensorflow.Keras
return -math_ops.reduce_sum(target * math_ops.log(output), new Axis(axis));
}
+ public Tensor sparse_categorical_crossentropy(Tensor target, Tensor output, bool from_logits = false, int axis = -1, int? ignore_class = null)
+ {
+ target = tf.cast(target, tf.int64);
+ if (!from_logits)
+ {
+ var epsilon_ = constant_op.constant(epsilon(), output.dtype.as_base_dtype());
+ output = tf.clip_by_value(output, epsilon_, 1 - epsilon_);
+ output = tf.math.log(output);
+ }
+ var output_rank = output.shape.ndim;
+ if (output_rank > -1)
+ {
+ axis = Math.Abs(axis) % output_rank;
+ if (axis != output_rank - 1)
+ {
+ /*var permutation = list(
+ itertools.chain(
+ range(axis), range(axis + 1, output_rank), [axis]
+ )
+ );
+ output = tf.transpose(output, perm: permutation);*/
+ throw new NotImplementedException("");
+ }
+
+ }
+
+ var output_shape = tf.shape(output);
+ var target_rank = target.shape.ndim;
+ var update_shape = target_rank > -1 && output_rank > -1 && target_rank != output_rank - 1;
+ if (update_shape)
+ {
+ /*var target = flatten(target);
+ output = tf.reshape(output, [-1, output_shape[-1]]);*/
+ throw new NotImplementedException("");
+ }
+
+ if (ignore_class.HasValue)
+ {
+ throw new NotImplementedException("");
+ }
+
+ var res = tf.nn.sparse_softmax_cross_entropy_with_logits(labels: target, logits: output);
+
+ if (ignore_class.HasValue)
+ {
+ throw new NotImplementedException("");
+ }
+
+ if (update_shape && output_rank >= 3)
+ {
+ // If our output includes timesteps or
+ // spatial dimensions we need to reshape
+ res = tf.reshape(res, output_shape[":-1"]);
+ }
+
+ return res;
+ }
+
public Tensor binary_crossentropy(Tensor target, Tensor output, bool from_logits = false)
{
if (from_logits)
diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
index 585fefae..e3881cf1 100644
--- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
+++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
@@ -27,6 +27,11 @@
return keras.backend.categorical_crossentropy(y_true, y_pred, from_logits: from_logits, axis: axis);
}
+ public Tensor sparse_categorical_crossentropy(Tensor y_true, Tensor y_pred, bool from_logits = false, int? ignore_class = null, Axis? axis = null)
+ {
+ return keras.backend.sparse_categorical_crossentropy(y_true, y_pred, from_logits: from_logits, axis: axis ?? -1, ignore_class: ignore_class);
+ }
+
///
/// Calculates how often predictions matches integer labels.
///
@@ -103,5 +108,14 @@
public IMetricFunc Recall(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT)
=> new Recall(thresholds: thresholds, top_k: top_k, class_id: class_id, name: name, dtype: dtype);
+
+ public IMetricFunc SparseCategoricalCrossentropy(string name = "sparse_categorical_crossentropy", TF_DataType dtype = TF_DataType.TF_FLOAT, bool from_logits = false, int? ignore_class = null, Axis? axis = null)
+ => new SparseCategoricalCrossentropy(name: name, dtype: dtype, from_logits: from_logits, ignore_class: ignore_class, axis: axis ?? -1);
+
+ public IMetricFunc SparseTopKCategoricalAccuracy(int k = 5, string name = "sparse_top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
+ => new SparseTopKCategoricalAccuracy(k: k, name: name, dtype: dtype);
+
+ public IMetricFunc SparseCategoricalAccuracy(string name = "sparse_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
+ => new SparseCategoricalAccuracy(name: name, dtype: dtype);
}
}
diff --git a/src/TensorFlowNET.Keras/Metrics/SparseCategoricalAccuracy.cs b/src/TensorFlowNET.Keras/Metrics/SparseCategoricalAccuracy.cs
new file mode 100644
index 00000000..6cad9aac
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Metrics/SparseCategoricalAccuracy.cs
@@ -0,0 +1,11 @@
+namespace Tensorflow.Keras.Metrics;
+
+public class SparseCategoricalAccuracy : MeanMetricWrapper
+{
+ public SparseCategoricalAccuracy(string name = "sparse_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
+ : base((yt, yp) => metrics_utils.sparse_categorical_matches(yt, yp),
+ name: name,
+ dtype: dtype)
+ {
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Metrics/SparseCategoricalCrossentropy.cs b/src/TensorFlowNET.Keras/Metrics/SparseCategoricalCrossentropy.cs
new file mode 100644
index 00000000..d517da91
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Metrics/SparseCategoricalCrossentropy.cs
@@ -0,0 +1,16 @@
+namespace Tensorflow.Keras.Metrics;
+
+public class SparseCategoricalCrossentropy : MeanMetricWrapper
+{
+ public SparseCategoricalCrossentropy(string name = "sparse_categorical_crossentropy",
+ TF_DataType dtype = TF_DataType.TF_FLOAT,
+ bool from_logits = false,
+ int? ignore_class = null,
+ Axis? axis = null)
+ : base((yt, yp) => keras.metrics.sparse_categorical_crossentropy(
+ yt, yp, from_logits: from_logits, ignore_class: ignore_class, axis: axis ?? -1),
+ name: name,
+ dtype: dtype)
+ {
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Metrics/SparseTopKCategoricalAccuracy.cs b/src/TensorFlowNET.Keras/Metrics/SparseTopKCategoricalAccuracy.cs
new file mode 100644
index 00000000..eb6d9f3b
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Metrics/SparseTopKCategoricalAccuracy.cs
@@ -0,0 +1,11 @@
+namespace Tensorflow.Keras.Metrics;
+
+public class SparseTopKCategoricalAccuracy : MeanMetricWrapper
+{
+ public SparseTopKCategoricalAccuracy(int k = 5, string name = "sparse_top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
+ : base((yt, yp) => metrics_utils.sparse_top_k_categorical_matches(yt, yp, k),
+ name: name,
+ dtype: dtype)
+ {
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
index 269bb1fb..be6a49ec 100644
--- a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
+++ b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
@@ -73,7 +73,7 @@ public class metrics_utils
y_true = tf.squeeze(y_true, new Shape(-1));
}
y_pred = tf.math.argmax(y_pred, axis: -1);
-
+ y_pred = tf.cast(y_pred, y_true.dtype);
var matches = tf.cast(
tf.equal(y_true, y_pred),
dtype: keras.backend.floatx()
diff --git a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
index 267cef81..04810db3 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
@@ -74,6 +74,26 @@ public class MetricsTest : EagerModeTestBase
Assert.AreEqual(r, 0.3f);
}
+ ///
+ /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseCategoricalAccuracy
+ ///
+ [TestMethod]
+ public void SparseCategoricalAccuracy()
+ {
+ var y_true = np.array(new[] { 2, 1 });
+ var y_pred = np.array(new[,] { { 0.1f, 0.6f, 0.3f }, { 0.05f, 0.95f, 0f } });
+ var m = tf.keras.metrics.SparseCategoricalAccuracy();
+ m.update_state(y_true, y_pred);
+ var r = m.result().numpy();
+ Assert.AreEqual(r, 0.5f);
+
+ m.reset_states();
+ var weights = np.array(new[] { 0.7f, 0.3f });
+ m.update_state(y_true, y_pred, sample_weight: weights);
+ r = m.result().numpy();
+ Assert.AreEqual(r, 0.3f);
+ }
+
///
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CategoricalCrossentropy
///
@@ -94,6 +114,20 @@ public class MetricsTest : EagerModeTestBase
Assert.AreEqual(r, 1.6271976f);
}
+ ///
+ /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseCategoricalCrossentropy
+ ///
+ [TestMethod]
+ public void SparseCategoricalCrossentropy()
+ {
+ var y_true = np.array(new[] { 1, 2 });
+ var y_pred = np.array(new[,] { { 0.05f, 0.95f, 0f }, { 0.1f, 0.8f, 0.1f } });
+ var m = tf.keras.metrics.SparseCategoricalCrossentropy();
+ m.update_state(y_true, y_pred);
+ var r = m.result().numpy();
+ Assert.AreEqual(r, 1.1769392f);
+ }
+
///
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CosineSimilarity
///
@@ -207,6 +241,26 @@ public class MetricsTest : EagerModeTestBase
Assert.AreEqual(r, 0.3f);
}
+ ///
+ /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseTopKCategoricalAccuracy
+ ///
+ [TestMethod]
+ public void SparseTopKCategoricalAccuracy()
+ {
+ var y_true = np.array(new[] { 2, 1 });
+ var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } });
+ var m = tf.keras.metrics.SparseTopKCategoricalAccuracy(k: 1);
+ m.update_state(y_true, y_pred);
+ var r = m.result().numpy();
+ Assert.AreEqual(r, 0.5f);
+
+ m.reset_states();
+ var weights = np.array(new[] { 0.7f, 0.3f });
+ m.update_state(y_true, y_pred, sample_weight: weights);
+ r = m.result().numpy();
+ Assert.AreEqual(r, 0.3f);
+ }
+
///
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/top_k_categorical_accuracy
///