diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
index 7c6e3e00..92360a6d 100644
--- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
+++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
@@ -5,7 +5,7 @@
TensorFlow.NET
Tensorflow
2.2.0
- 0.40.0
+ 0.40.1
8.0
Haiping Chen, Meinrad Recheis, Eli Belash
SciSharp STACK
@@ -19,7 +19,7 @@
Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io
- 0.40.0.0
+ 0.40.1.0
tf.net 0.20.x and above are based on tensorflow native 2.x.
* Eager Mode is added finally.
@@ -32,7 +32,7 @@ TensorFlow .NET v0.3x is focused on making more Keras API works.
Keras API is a separate package released as TensorFlow.Keras.
tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library.
- 0.40.0.0
+ 0.40.1.0
LICENSE
true
true
diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
index 25b97007..791306ca 100644
--- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs
+++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
@@ -596,8 +596,6 @@ would not be rank 1.", tensor.op.get_attr("axis")));
case TF_DataType.TF_STRING:
return string.Join(string.Empty, nd.ToArray()
.Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString()));
- case TF_DataType.TF_BOOL:
- return nd.GetBoolean(0).ToString();
case TF_DataType.TF_VARIANT:
case TF_DataType.TF_RESOURCE:
return "";
diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs
index c82acce4..a9bcb8e3 100644
--- a/src/TensorFlowNET.Keras/BackendImpl.cs
+++ b/src/TensorFlowNET.Keras/BackendImpl.cs
@@ -137,6 +137,14 @@ namespace Tensorflow.Keras
{
_MANUAL_VAR_INIT = value;
}
+
+ public Tensor mean(Tensor x, int axis = -1, bool keepdims = false)
+ {
+ if (x.dtype.as_base_dtype() == TF_DataType.TF_BOOL)
+ x = math_ops.cast(x, TF_DataType.TF_FLOAT);
+ return math_ops.reduce_mean(x, axis: new[] { axis }, keepdims: false);
+ }
+
public GraphLearningPhase learning_phase()
{
var graph = tf.get_default_graph();
diff --git a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
index 3870c29b..39ba2a27 100644
--- a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
+++ b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
@@ -68,7 +68,9 @@ namespace Tensorflow.Keras.Engine
bool is_binary = y_p_last_dim == 1;
bool is_sparse_categorical = (y_t_rank < y_p_rank || y_t_last_dim == 1) && y_p_last_dim > 1;
- if (is_sparse_categorical)
+ if (is_binary)
+ metric_obj = keras.metrics.binary_accuracy;
+ else if (is_sparse_categorical)
metric_obj = keras.metrics.sparse_categorical_accuracy;
else
metric_obj = keras.metrics.categorical_accuracy;
diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
index 105b8b3c..f165a347 100644
--- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
+++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
@@ -1,7 +1,16 @@
-namespace Tensorflow.Keras.Metrics
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Metrics
{
public class MetricsApi
{
+ public Tensor binary_accuracy(Tensor y_true, Tensor y_pred)
+ {
+ float threshold = 0.5f;
+ y_pred = math_ops.cast(y_pred > threshold, y_pred.dtype);
+ return keras.backend.mean(math_ops.equal(y_true, y_pred), axis: -1);
+ }
+
public Tensor categorical_accuracy(Tensor y_true, Tensor y_pred)
{
var eql = math_ops.equal(math_ops.argmax(y_true, -1), math_ops.argmax(y_pred, -1));
diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
index 0c50a5a1..6d246126 100644
--- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
+++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
@@ -6,10 +6,10 @@
8.0
Tensorflow.Keras
AnyCPU;x64
- 0.5.0
+ 0.5.1
Haiping Chen
Keras for .NET
- Apache 2.0, Haiping Chen 2020
+ Apache 2.0, Haiping Chen 2021
TensorFlow.Keras
https://github.com/SciSharp/TensorFlow.NET
https://avatars3.githubusercontent.com/u/44989469?s=200&v=4
@@ -35,8 +35,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
Git
true
Open.snk
- 0.5.0.0
- 0.5.0.0
+ 0.5.1.0
+ 0.5.1.0
LICENSE