Browse Source

Fix binary_accuracy for keras.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
b54cbaa772
6 changed files with 28 additions and 11 deletions
  1. +3
    -3
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  2. +0
    -2
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  3. +8
    -0
      src/TensorFlowNET.Keras/BackendImpl.cs
  4. +3
    -1
      src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
  5. +10
    -1
      src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
  6. +4
    -4
      src/TensorFlowNET.Keras/Tensorflow.Keras.csproj

+ 3
- 3
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>2.2.0</TargetTensorFlow>
<Version>0.40.0</Version>
<Version>0.40.1</Version>
<LangVersion>8.0</LangVersion>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
<Company>SciSharp STACK</Company>
@@ -19,7 +19,7 @@
<Description>Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.40.0.0</AssemblyVersion>
<AssemblyVersion>0.40.1.0</AssemblyVersion>
<PackageReleaseNotes>tf.net 0.20.x and above are based on tensorflow native 2.x.

* Eager Mode is added finally.
@@ -32,7 +32,7 @@ TensorFlow .NET v0.3x is focused on making more Keras API works.
Keras API is a separate package released as TensorFlow.Keras.

tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library.</PackageReleaseNotes>
<FileVersion>0.40.0.0</FileVersion>
<FileVersion>0.40.1.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly>


+ 0
- 2
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -596,8 +596,6 @@ would not be rank 1.", tensor.op.get_attr("axis")));
case TF_DataType.TF_STRING:
return string.Join(string.Empty, nd.ToArray<byte>()
.Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString()));
case TF_DataType.TF_BOOL:
return nd.GetBoolean(0).ToString();
case TF_DataType.TF_VARIANT:
case TF_DataType.TF_RESOURCE:
return "<unprintable>";


+ 8
- 0
src/TensorFlowNET.Keras/BackendImpl.cs View File

@@ -137,6 +137,14 @@ namespace Tensorflow.Keras
{
_MANUAL_VAR_INIT = value;
}

public Tensor mean(Tensor x, int axis = -1, bool keepdims = false)
{
if (x.dtype.as_base_dtype() == TF_DataType.TF_BOOL)
x = math_ops.cast(x, TF_DataType.TF_FLOAT);
return math_ops.reduce_mean(x, axis: new[] { axis }, keepdims: false);
}

public GraphLearningPhase learning_phase()
{
var graph = tf.get_default_graph();


+ 3
- 1
src/TensorFlowNET.Keras/Engine/MetricsContainer.cs View File

@@ -68,7 +68,9 @@ namespace Tensorflow.Keras.Engine
bool is_binary = y_p_last_dim == 1;
bool is_sparse_categorical = (y_t_rank < y_p_rank || y_t_last_dim == 1) && y_p_last_dim > 1;

if (is_sparse_categorical)
if (is_binary)
metric_obj = keras.metrics.binary_accuracy;
else if (is_sparse_categorical)
metric_obj = keras.metrics.sparse_categorical_accuracy;
else
metric_obj = keras.metrics.categorical_accuracy;


+ 10
- 1
src/TensorFlowNET.Keras/Metrics/MetricsApi.cs View File

@@ -1,7 +1,16 @@
namespace Tensorflow.Keras.Metrics
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Metrics
{
public class MetricsApi
{
public Tensor binary_accuracy(Tensor y_true, Tensor y_pred)
{
float threshold = 0.5f;
y_pred = math_ops.cast(y_pred > threshold, y_pred.dtype);
return keras.backend.mean(math_ops.equal(y_true, y_pred), axis: -1);
}

public Tensor categorical_accuracy(Tensor y_true, Tensor y_pred)
{
var eql = math_ops.equal(math_ops.argmax(y_true, -1), math_ops.argmax(y_pred, -1));


+ 4
- 4
src/TensorFlowNET.Keras/Tensorflow.Keras.csproj View File

@@ -6,10 +6,10 @@
<LangVersion>8.0</LangVersion>
<RootNamespace>Tensorflow.Keras</RootNamespace>
<Platforms>AnyCPU;x64</Platforms>
<Version>0.5.0</Version>
<Version>0.5.1</Version>
<Authors>Haiping Chen</Authors>
<Product>Keras for .NET</Product>
<Copyright>Apache 2.0, Haiping Chen 2020</Copyright>
<Copyright>Apache 2.0, Haiping Chen 2021</Copyright>
<PackageId>TensorFlow.Keras</PackageId>
<PackageProjectUrl>https://github.com/SciSharp/TensorFlow.NET</PackageProjectUrl>
<PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&amp;v=4</PackageIconUrl>
@@ -35,8 +35,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<RepositoryType>Git</RepositoryType>
<SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
<AssemblyVersion>0.5.0.0</AssemblyVersion>
<FileVersion>0.5.0.0</FileVersion>
<AssemblyVersion>0.5.1.0</AssemblyVersion>
<FileVersion>0.5.1.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
</PropertyGroup>



Loading…
Cancel
Save