diff --git a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs index 5d08cc78..dbe4ac3f 100644 --- a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs +++ b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs @@ -22,6 +22,20 @@ public interface IMetricsApi /// Sparse categorical accuracy values. Tensor sparse_categorical_accuracy(Tensor y_true, Tensor y_pred); + /// + /// Computes the sparse categorical crossentropy loss. + /// + /// + /// + /// + /// + /// + /// + Tensor sparse_categorical_crossentropy(Tensor y_true, Tensor y_pred, + bool from_logits = false, + int? ignore_class = null, + Axis? axis = null); + /// /// Computes how often targets are in the top `K` predictions. /// @@ -56,6 +70,16 @@ public interface IMetricsApi float label_smoothing = 0f, Axis? axis = null); + /// + /// Computes the crossentropy metric between the labels and predictions. + /// + /// + IMetricFunc SparseCategoricalCrossentropy(string name = "sparse_categorical_crossentropy", + TF_DataType dtype = TF_DataType.TF_FLOAT, + bool from_logits = false, + int? ignore_class = null, + Axis? axis = null); + /// /// Computes the crossentropy metric between the labels and predictions. /// @@ -63,6 +87,13 @@ public interface IMetricsApi IMetricFunc CategoricalAccuracy(string name = "categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT); + /// + /// Calculates how often predictions match integer labels. + /// + /// + IMetricFunc SparseCategoricalAccuracy(string name = "sparse_categorical_accuracy", + TF_DataType dtype = TF_DataType.TF_FLOAT); + /// /// Computes the cosine similarity between the labels and predictions. /// @@ -114,6 +145,15 @@ public interface IMetricsApi string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT); + /// + /// Computes how often integer targets are in the top K predictions. + /// + /// + /// + IMetricFunc SparseTopKCategoricalAccuracy(int k = 5, + string name = "sparse_top_k_categorical_accuracy", + TF_DataType dtype = TF_DataType.TF_FLOAT); + /// /// Computes the precision of the predictions with respect to the labels. /// diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index 0c9da015..c49fc140 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -276,6 +276,64 @@ namespace Tensorflow.Keras return -math_ops.reduce_sum(target * math_ops.log(output), new Axis(axis)); } + public Tensor sparse_categorical_crossentropy(Tensor target, Tensor output, bool from_logits = false, int axis = -1, int? ignore_class = null) + { + target = tf.cast(target, tf.int64); + if (!from_logits) + { + var epsilon_ = constant_op.constant(epsilon(), output.dtype.as_base_dtype()); + output = tf.clip_by_value(output, epsilon_, 1 - epsilon_); + output = tf.math.log(output); + } + var output_rank = output.shape.ndim; + if (output_rank > -1) + { + axis = Math.Abs(axis) % output_rank; + if (axis != output_rank - 1) + { + /*var permutation = list( + itertools.chain( + range(axis), range(axis + 1, output_rank), [axis] + ) + ); + output = tf.transpose(output, perm: permutation);*/ + throw new NotImplementedException(""); + } + + } + + var output_shape = tf.shape(output); + var target_rank = target.shape.ndim; + var update_shape = target_rank > -1 && output_rank > -1 && target_rank != output_rank - 1; + if (update_shape) + { + /*var target = flatten(target); + output = tf.reshape(output, [-1, output_shape[-1]]);*/ + throw new NotImplementedException(""); + } + + if (ignore_class.HasValue) + { + throw new NotImplementedException(""); + } + + var res = tf.nn.sparse_softmax_cross_entropy_with_logits(labels: target, logits: output); + + if (ignore_class.HasValue) + { + throw new NotImplementedException(""); + } + + if (update_shape && output_rank >= 3) + { + // If our output includes timesteps or + // spatial dimensions we need to reshape + res = tf.reshape(res, output_shape[":-1"]); + } + + return res; + } + public Tensor binary_crossentropy(Tensor target, Tensor output, bool from_logits = false) { if (from_logits) diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs index 585fefae..e3881cf1 100644 --- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs +++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs @@ -27,6 +27,11 @@ return keras.backend.categorical_crossentropy(y_true, y_pred, from_logits: from_logits, axis: axis); } + public Tensor sparse_categorical_crossentropy(Tensor y_true, Tensor y_pred, bool from_logits = false, int? ignore_class = null, Axis? axis = null) + { + return keras.backend.sparse_categorical_crossentropy(y_true, y_pred, from_logits: from_logits, axis: axis ?? -1, ignore_class: ignore_class); + } + /// /// Calculates how often predictions matches integer labels. /// @@ -103,5 +108,14 @@ public IMetricFunc Recall(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT) => new Recall(thresholds: thresholds, top_k: top_k, class_id: class_id, name: name, dtype: dtype); + + public IMetricFunc SparseCategoricalCrossentropy(string name = "sparse_categorical_crossentropy", TF_DataType dtype = TF_DataType.TF_FLOAT, bool from_logits = false, int? ignore_class = null, Axis? axis = null) + => new SparseCategoricalCrossentropy(name: name, dtype: dtype, from_logits: from_logits, ignore_class: ignore_class, axis: axis ?? -1); + + public IMetricFunc SparseTopKCategoricalAccuracy(int k = 5, string name = "sparse_top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) + => new SparseTopKCategoricalAccuracy(k: k, name: name, dtype: dtype); + + public IMetricFunc SparseCategoricalAccuracy(string name = "sparse_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) + => new SparseCategoricalAccuracy(name: name, dtype: dtype); } } diff --git a/src/TensorFlowNET.Keras/Metrics/SparseCategoricalAccuracy.cs b/src/TensorFlowNET.Keras/Metrics/SparseCategoricalAccuracy.cs new file mode 100644 index 00000000..6cad9aac --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/SparseCategoricalAccuracy.cs @@ -0,0 +1,11 @@ +namespace Tensorflow.Keras.Metrics; + +public class SparseCategoricalAccuracy : MeanMetricWrapper +{ + public SparseCategoricalAccuracy(string name = "sparse_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) + : base((yt, yp) => metrics_utils.sparse_categorical_matches(yt, yp), + name: name, + dtype: dtype) + { + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/SparseCategoricalCrossentropy.cs b/src/TensorFlowNET.Keras/Metrics/SparseCategoricalCrossentropy.cs new file mode 100644 index 00000000..d517da91 --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/SparseCategoricalCrossentropy.cs @@ -0,0 +1,16 @@ +namespace Tensorflow.Keras.Metrics; + +public class SparseCategoricalCrossentropy : MeanMetricWrapper +{ + public SparseCategoricalCrossentropy(string name = "sparse_categorical_crossentropy", + TF_DataType dtype = TF_DataType.TF_FLOAT, + bool from_logits = false, + int? ignore_class = null, + Axis? axis = null) + : base((yt, yp) => keras.metrics.sparse_categorical_crossentropy( + yt, yp, from_logits: from_logits, ignore_class: ignore_class, axis: axis ?? -1), + name: name, + dtype: dtype) + { + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/SparseTopKCategoricalAccuracy.cs b/src/TensorFlowNET.Keras/Metrics/SparseTopKCategoricalAccuracy.cs new file mode 100644 index 00000000..eb6d9f3b --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/SparseTopKCategoricalAccuracy.cs @@ -0,0 +1,11 @@ +namespace Tensorflow.Keras.Metrics; + +public class SparseTopKCategoricalAccuracy : MeanMetricWrapper +{ + public SparseTopKCategoricalAccuracy(int k = 5, string name = "sparse_top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) + : base((yt, yp) => metrics_utils.sparse_top_k_categorical_matches(yt, yp, k), + name: name, + dtype: dtype) + { + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs index 269bb1fb..be6a49ec 100644 --- a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs +++ b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs @@ -73,7 +73,7 @@ public class metrics_utils y_true = tf.squeeze(y_true, new Shape(-1)); } y_pred = tf.math.argmax(y_pred, axis: -1); - + y_pred = tf.cast(y_pred, y_true.dtype); var matches = tf.cast( tf.equal(y_true, y_pred), dtype: keras.backend.floatx() diff --git a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs index 267cef81..04810db3 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs @@ -74,6 +74,26 @@ public class MetricsTest : EagerModeTestBase Assert.AreEqual(r, 0.3f); } + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseCategoricalAccuracy + /// + [TestMethod] + public void SparseCategoricalAccuracy() + { + var y_true = np.array(new[] { 2, 1 }); + var y_pred = np.array(new[,] { { 0.1f, 0.6f, 0.3f }, { 0.05f, 0.95f, 0f } }); + var m = tf.keras.metrics.SparseCategoricalAccuracy(); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, 0.5f); + + m.reset_states(); + var weights = np.array(new[] { 0.7f, 0.3f }); + m.update_state(y_true, y_pred, sample_weight: weights); + r = m.result().numpy(); + Assert.AreEqual(r, 0.3f); + } + /// /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CategoricalCrossentropy /// @@ -94,6 +114,20 @@ public class MetricsTest : EagerModeTestBase Assert.AreEqual(r, 1.6271976f); } + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseCategoricalCrossentropy + /// + [TestMethod] + public void SparseCategoricalCrossentropy() + { + var y_true = np.array(new[] { 1, 2 }); + var y_pred = np.array(new[,] { { 0.05f, 0.95f, 0f }, { 0.1f, 0.8f, 0.1f } }); + var m = tf.keras.metrics.SparseCategoricalCrossentropy(); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, 1.1769392f); + } + /// /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CosineSimilarity /// @@ -207,6 +241,26 @@ public class MetricsTest : EagerModeTestBase Assert.AreEqual(r, 0.3f); } + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseTopKCategoricalAccuracy + /// + [TestMethod] + public void SparseTopKCategoricalAccuracy() + { + var y_true = np.array(new[] { 2, 1 }); + var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } }); + var m = tf.keras.metrics.SparseTopKCategoricalAccuracy(k: 1); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, 0.5f); + + m.reset_states(); + var weights = np.array(new[] { 0.7f, 0.3f }); + m.update_state(y_true, y_pred, sample_weight: weights); + r = m.result().numpy(); + Assert.AreEqual(r, 0.3f); + } + /// /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/top_k_categorical_accuracy ///