From b54cbaa772cdd791155e3aeac90b45c656868a22 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 20 Mar 2021 20:37:25 -0500 Subject: [PATCH] Fix binary_accuracy for keras. --- src/TensorFlowNET.Core/Tensorflow.Binding.csproj | 6 +++--- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 2 -- src/TensorFlowNET.Keras/BackendImpl.cs | 8 ++++++++ src/TensorFlowNET.Keras/Engine/MetricsContainer.cs | 4 +++- src/TensorFlowNET.Keras/Metrics/MetricsApi.cs | 11 ++++++++++- src/TensorFlowNET.Keras/Tensorflow.Keras.csproj | 8 ++++---- 6 files changed, 28 insertions(+), 11 deletions(-) diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index 7c6e3e00..92360a6d 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 2.2.0 - 0.40.0 + 0.40.1 8.0 Haiping Chen, Meinrad Recheis, Eli Belash SciSharp STACK @@ -19,7 +19,7 @@ Google's TensorFlow full binding in .NET Standard. Building, training and infering deep learning models. https://tensorflownet.readthedocs.io - 0.40.0.0 + 0.40.1.0 tf.net 0.20.x and above are based on tensorflow native 2.x. * Eager Mode is added finally. @@ -32,7 +32,7 @@ TensorFlow .NET v0.3x is focused on making more Keras API works. Keras API is a separate package released as TensorFlow.Keras. tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library. - 0.40.0.0 + 0.40.1.0 LICENSE true true diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 25b97007..791306ca 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -596,8 +596,6 @@ would not be rank 1.", tensor.op.get_attr("axis"))); case TF_DataType.TF_STRING: return string.Join(string.Empty, nd.ToArray() .Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())); - case TF_DataType.TF_BOOL: - return nd.GetBoolean(0).ToString(); case TF_DataType.TF_VARIANT: case TF_DataType.TF_RESOURCE: return ""; diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index c82acce4..a9bcb8e3 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -137,6 +137,14 @@ namespace Tensorflow.Keras { _MANUAL_VAR_INIT = value; } + + public Tensor mean(Tensor x, int axis = -1, bool keepdims = false) + { + if (x.dtype.as_base_dtype() == TF_DataType.TF_BOOL) + x = math_ops.cast(x, TF_DataType.TF_FLOAT); + return math_ops.reduce_mean(x, axis: new[] { axis }, keepdims: false); + } + public GraphLearningPhase learning_phase() { var graph = tf.get_default_graph(); diff --git a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs index 3870c29b..39ba2a27 100644 --- a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs +++ b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs @@ -68,7 +68,9 @@ namespace Tensorflow.Keras.Engine bool is_binary = y_p_last_dim == 1; bool is_sparse_categorical = (y_t_rank < y_p_rank || y_t_last_dim == 1) && y_p_last_dim > 1; - if (is_sparse_categorical) + if (is_binary) + metric_obj = keras.metrics.binary_accuracy; + else if (is_sparse_categorical) metric_obj = keras.metrics.sparse_categorical_accuracy; else metric_obj = keras.metrics.categorical_accuracy; diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs index 105b8b3c..f165a347 100644 --- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs +++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs @@ -1,7 +1,16 @@ -namespace Tensorflow.Keras.Metrics +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Metrics { public class MetricsApi { + public Tensor binary_accuracy(Tensor y_true, Tensor y_pred) + { + float threshold = 0.5f; + y_pred = math_ops.cast(y_pred > threshold, y_pred.dtype); + return keras.backend.mean(math_ops.equal(y_true, y_pred), axis: -1); + } + public Tensor categorical_accuracy(Tensor y_true, Tensor y_pred) { var eql = math_ops.equal(math_ops.argmax(y_true, -1), math_ops.argmax(y_pred, -1)); diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj index 0c50a5a1..6d246126 100644 --- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj +++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj @@ -6,10 +6,10 @@ 8.0 Tensorflow.Keras AnyCPU;x64 - 0.5.0 + 0.5.1 Haiping Chen Keras for .NET - Apache 2.0, Haiping Chen 2020 + Apache 2.0, Haiping Chen 2021 TensorFlow.Keras https://github.com/SciSharp/TensorFlow.NET https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 @@ -35,8 +35,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac Git true Open.snk - 0.5.0.0 - 0.5.0.0 + 0.5.1.0 + 0.5.1.0 LICENSE