diff --git a/src/TensorFlowNET.Core/APIs/tf.linalg.cs b/src/TensorFlowNET.Core/APIs/tf.linalg.cs
index 10c09d99..32f64ec3 100644
--- a/src/TensorFlowNET.Core/APIs/tf.linalg.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.linalg.cs
@@ -54,6 +54,12 @@ namespace Tensorflow
public Tensor global_norm(Tensor[] t_list, string name = null)
=> clip_ops.global_norm(t_list, name: name);
+ public Tensor l2_normalize(Tensor x,
+ int axis = 0,
+ float epsilon = 1e-12f,
+ string name = null)
+ => nn_impl.l2_normalize(x, axis: axis, epsilon: constant_op.constant(epsilon), name: name);
+
public Tensor lstsq(Tensor matrix, Tensor rhs,
NDArray l2_regularizer = null, bool fast = true, string name = null)
=> ops.matrix_solve_ls(matrix, rhs, l2_regularizer: l2_regularizer, fast: fast, name: name);
diff --git a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
index 75946303..e4575620 100644
--- a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
@@ -31,6 +31,13 @@ public interface IMetricsApi
///
Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5);
+ ///
+ /// Calculates how often predictions equal labels.
+ ///
+ ///
+ IMetricFunc Accuracy(string name = "accuracy",
+ TF_DataType dtype = TF_DataType.TF_FLOAT);
+
///
/// Calculates how often predictions match binary labels.
///
@@ -56,6 +63,14 @@ public interface IMetricsApi
IMetricFunc CategoricalAccuracy(string name = "categorical_accuracy",
TF_DataType dtype = TF_DataType.TF_FLOAT);
+ ///
+ /// Computes the cosine similarity between the labels and predictions.
+ ///
+ ///
+ IMetricFunc CosineSimilarity(string name = "cosine_similarity",
+ TF_DataType dtype = TF_DataType.TF_FLOAT,
+ Axis? axis = null);
+
///
/// Computes how often targets are in the top K predictions.
///
diff --git a/src/TensorFlowNET.Keras/Metrics/Accuracy.cs b/src/TensorFlowNET.Keras/Metrics/Accuracy.cs
new file mode 100644
index 00000000..93a72467
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Metrics/Accuracy.cs
@@ -0,0 +1,11 @@
+namespace Tensorflow.Keras.Metrics;
+
+public class Accuracy : MeanMetricWrapper
+{
+ public Accuracy(string name = "accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
+ : base((yt, yp) => metrics_utils.accuracy(yt, yp),
+ name: name,
+ dtype: dtype)
+ {
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Metrics/CosineSimilarity.cs b/src/TensorFlowNET.Keras/Metrics/CosineSimilarity.cs
new file mode 100644
index 00000000..2a26bcdf
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Metrics/CosineSimilarity.cs
@@ -0,0 +1,11 @@
+namespace Tensorflow.Keras.Metrics;
+
+public class CosineSimilarity : MeanMetricWrapper
+{
+ public CosineSimilarity(string name = "cosine_similarity", TF_DataType dtype = TF_DataType.TF_FLOAT, Axis? axis = null)
+ : base((yt, yp) => metrics_utils.cosine_similarity(yt, yp, axis: axis ?? -1),
+ name: name,
+ dtype: dtype)
+ {
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
index fcd0516b..e207d27d 100644
--- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
+++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
@@ -71,6 +71,9 @@
);
}
+ public IMetricFunc Accuracy(string name = "accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
+ => new Accuracy(name: name, dtype: dtype);
+
public IMetricFunc BinaryAccuracy(string name = "binary_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT, float threshold = 5)
=> new BinaryAccuracy();
@@ -80,6 +83,9 @@
public IMetricFunc CategoricalCrossentropy(string name = "categorical_crossentropy", TF_DataType dtype = TF_DataType.TF_FLOAT, bool from_logits = false, float label_smoothing = 0, Axis? axis = null)
=> new CategoricalCrossentropy();
+ public IMetricFunc CosineSimilarity(string name = "cosine_similarity", TF_DataType dtype = TF_DataType.TF_FLOAT, Axis? axis = null)
+ => new CosineSimilarity(name: name, dtype: dtype, axis: axis ?? -1);
+
public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
=> new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype);
diff --git a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
index 0f523e7e..f4bfc3da 100644
--- a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
+++ b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
@@ -4,12 +4,26 @@ namespace Tensorflow.Keras.Metrics;
public class metrics_utils
{
+ public static Tensor accuracy(Tensor y_true, Tensor y_pred)
+ {
+ if (y_true.dtype != y_pred.dtype)
+ y_pred = tf.cast(y_pred, y_true.dtype);
+ return tf.cast(tf.equal(y_true, y_pred), keras.backend.floatx());
+ }
+
public static Tensor binary_matches(Tensor y_true, Tensor y_pred, float threshold = 0.5f)
{
y_pred = tf.cast(y_pred > threshold, y_pred.dtype);
return tf.cast(tf.equal(y_true, y_pred), keras.backend.floatx());
}
+ public static Tensor cosine_similarity(Tensor y_true, Tensor y_pred, Axis? axis = null)
+ {
+ y_true = tf.linalg.l2_normalize(y_true, axis: axis ?? -1);
+ y_pred = tf.linalg.l2_normalize(y_pred, axis: axis ?? -1);
+ return tf.reduce_sum(y_true * y_pred, axis: axis ?? -1);
+ }
+
///
/// Creates float Tensor, 1.0 for label-prediction match, 0.0 for mismatch.
///
diff --git a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
index 9389af96..90be51bd 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs
@@ -14,6 +14,26 @@ namespace TensorFlowNET.Keras.UnitTest;
[TestClass]
public class MetricsTest : EagerModeTestBase
{
+ ///
+ /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Accuracy
+ ///
+ [TestMethod]
+ public void Accuracy()
+ {
+ var y_true = np.array(new[,] { { 1 }, { 2 }, { 3 }, { 4 } });
+ var y_pred = np.array(new[,] { { 0f }, { 2f }, { 3f }, { 4f } });
+ var m = tf.keras.metrics.Accuracy();
+ m.update_state(y_true, y_pred);
+ var r = m.result().numpy();
+ Assert.AreEqual(r, 0.75f);
+
+ m.reset_states();
+ var weights = np.array(new[] { 1f, 1f, 0f, 0f });
+ m.update_state(y_true, y_pred, sample_weight: weights);
+ r = m.result().numpy();
+ Assert.AreEqual(r, 0.5f);
+ }
+
///
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/BinaryAccuracy
///
@@ -23,14 +43,14 @@ public class MetricsTest : EagerModeTestBase
var y_true = np.array(new[,] { { 1 }, { 1 },{ 0 }, { 0 } });
var y_pred = np.array(new[,] { { 0.98f }, { 1f }, { 0f }, { 0.6f } });
var m = tf.keras.metrics.BinaryAccuracy();
- /*m.update_state(y_true, y_pred);
+ m.update_state(y_true, y_pred);
var r = m.result().numpy();
Assert.AreEqual(r, 0.75f);
- m.reset_states();*/
+ m.reset_states();
var weights = np.array(new[] { 1f, 0f, 0f, 1f });
m.update_state(y_true, y_pred, sample_weight: weights);
- var r = m.result().numpy();
+ r = m.result().numpy();
Assert.AreEqual(r, 0.5f);
}
@@ -74,6 +94,26 @@ public class MetricsTest : EagerModeTestBase
Assert.AreEqual(r, 1.6271976f);
}
+ ///
+ /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CosineSimilarity
+ ///
+ [TestMethod]
+ public void CosineSimilarity()
+ {
+ var y_true = np.array(new[,] { { 0, 1 }, { 1, 1 } });
+ var y_pred = np.array(new[,] { { 1f, 0f }, { 1f, 1f } });
+ var m = tf.keras.metrics.CosineSimilarity(axis: 1);
+ m.update_state(y_true, y_pred);
+ var r = m.result().numpy();
+ Assert.AreEqual(r, 0.49999997f);
+
+ m.reset_states();
+ var weights = np.array(new[] { 0.3f, 0.7f });
+ m.update_state(y_true, y_pred, sample_weight: weights);
+ r = m.result().numpy();
+ Assert.AreEqual(r, 0.6999999f);
+ }
+
///
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy
///