Browse Source

Add metrics of F1Score and FBetaScore.

tags/v0.100.4-load-saved-model
Haiping Chen 2 years ago
parent
commit
067c1ff92a
8 changed files with 214 additions and 6 deletions
  1. +9
    -0
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +21
    -0
      src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
  3. +3
    -3
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  4. +13
    -0
      src/TensorFlowNET.Keras/Metrics/F1Score.cs
  5. +131
    -0
      src/TensorFlowNET.Keras/Metrics/FBetaScore.cs
  6. +6
    -0
      src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
  7. +3
    -3
      src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
  8. +28
    -0
      test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs

+ 9
- 0
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -36,6 +36,15 @@ namespace Tensorflow
public Tensor erf(Tensor x, string name = null) public Tensor erf(Tensor x, string name = null)
=> math_ops.erf(x, name); => math_ops.erf(x, name);


public Tensor multiply(Tensor x, Tensor y, string name = null)
=> math_ops.multiply(x, y, name: name);

public Tensor divide_no_nan(Tensor a, Tensor b, string name = null)
=> math_ops.div_no_nan(a, b);

public Tensor square(Tensor x, string name = null)
=> math_ops.square(x, name: name);

public Tensor sum(Tensor x, Axis? axis = null, string name = null) public Tensor sum(Tensor x, Axis? axis = null, string name = null)
=> math_ops.reduce_sum(x, axis: axis, name: name); => math_ops.reduce_sum(x, axis: axis, name: name);




+ 21
- 0
src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs View File

@@ -71,6 +71,27 @@ public interface IMetricsApi
TF_DataType dtype = TF_DataType.TF_FLOAT, TF_DataType dtype = TF_DataType.TF_FLOAT,
Axis? axis = null); Axis? axis = null);


/// <summary>
/// Computes F-1 Score.
/// </summary>
/// <returns></returns>
IMetricFunc F1Score(int num_classes,
string? average = null,
float threshold = -1f,
string name = "fbeta_score",
TF_DataType dtype = TF_DataType.TF_FLOAT);

/// <summary>
/// Computes F-Beta score.
/// </summary>
/// <returns></returns>
IMetricFunc FBetaScore(int num_classes,
string? average = null,
float beta = 0.1f,
float threshold = -1f,
string name = "fbeta_score",
TF_DataType dtype = TF_DataType.TF_FLOAT);
/// <summary> /// <summary>
/// Computes how often targets are in the top K predictions. /// Computes how often targets are in the top K predictions.
/// </summary> /// </summary>


+ 3
- 3
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>Tensorflow.Binding</AssemblyName> <AssemblyName>Tensorflow.Binding</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace> <RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>2.10.0</TargetTensorFlow> <TargetTensorFlow>2.10.0</TargetTensorFlow>
<Version>0.100.3</Version>
<Version>0.100.4</Version>
<LangVersion>10.0</LangVersion> <LangVersion>10.0</LangVersion>
<Nullable>enable</Nullable> <Nullable>enable</Nullable>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
@@ -20,7 +20,7 @@
<Description>Google's TensorFlow full binding in .NET Standard. <Description>Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models. Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io</Description> https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.100.3.0</AssemblyVersion>
<AssemblyVersion>0.100.4.0</AssemblyVersion>
<PackageReleaseNotes> <PackageReleaseNotes>
tf.net 0.100.x and above are based on tensorflow native 2.10.0 tf.net 0.100.x and above are based on tensorflow native 2.10.0


@@ -38,7 +38,7 @@ https://tensorflownet.readthedocs.io</Description>
tf.net 0.7x.x aligns with TensorFlow v2.7.x native library. tf.net 0.7x.x aligns with TensorFlow v2.7.x native library.
tf.net 0.10x.x aligns with TensorFlow v2.10.x native library. tf.net 0.10x.x aligns with TensorFlow v2.10.x native library.
</PackageReleaseNotes> </PackageReleaseNotes>
<FileVersion>0.100.3.0</FileVersion>
<FileVersion>0.100.4.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile> <PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly> <SignAssembly>true</SignAssembly>


+ 13
- 0
src/TensorFlowNET.Keras/Metrics/F1Score.cs View File

@@ -0,0 +1,13 @@
namespace Tensorflow.Keras.Metrics;

public class F1Score : FBetaScore
{
public F1Score(int num_classes,
string? average = null,
float? threshold = -1f,
string name = "f1_score",
TF_DataType dtype = TF_DataType.TF_FLOAT)
: base(num_classes, average: average, threshold: threshold, beta: 1f, name: name, dtype: dtype)
{
}
}

+ 131
- 0
src/TensorFlowNET.Keras/Metrics/FBetaScore.cs View File

@@ -0,0 +1,131 @@
namespace Tensorflow.Keras.Metrics;

public class FBetaScore : Metric
{
int _num_classes;
string? _average;
Tensor _beta;
Tensor _threshold;
Axis _axis;
int[] _init_shape;

IVariableV1 true_positives;
IVariableV1 false_positives;
IVariableV1 false_negatives;
IVariableV1 weights_intermediate;

public FBetaScore(int num_classes,
string? average = null,
float beta = 0.1f,
float? threshold = -1f,
string name = "fbeta_score",
TF_DataType dtype = TF_DataType.TF_FLOAT)
: base(name: name, dtype: dtype)
{
_num_classes = num_classes;
_average = average;
_beta = constant_op.constant(beta);
_dtype = dtype;

if (threshold.HasValue)
{
_threshold = constant_op.constant(threshold);
}
_init_shape = new int[0];

if (average != "micro")
{
_axis = 0;
_init_shape = new int[] { num_classes };
}

true_positives = add_weight("true_positives", shape: _init_shape, initializer: tf.initializers.zeros_initializer());
false_positives = add_weight("false_positives", shape: _init_shape, initializer: tf.initializers.zeros_initializer());
false_negatives = add_weight("false_negatives", shape: _init_shape, initializer: tf.initializers.zeros_initializer());
weights_intermediate = add_weight("weights_intermediate", shape: _init_shape, initializer: tf.initializers.zeros_initializer());
}

public override Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{
if (_threshold == null)
{
_threshold = tf.reduce_max(y_pred, axis: -1, keepdims: true);
// make sure [0, 0, 0] doesn't become [1, 1, 1]
// Use abs(x) > eps, instead of x != 0 to check for zero
y_pred = tf.logical_and(y_pred >= _threshold, tf.abs(y_pred) > 1e-12);
}
else
{
y_pred = y_pred > _threshold;
}

y_true = tf.cast(y_true, _dtype);
y_pred = tf.cast(y_pred, _dtype);

true_positives.assign_add(_weighted_sum(y_pred * y_true, sample_weight));
false_positives.assign_add(
_weighted_sum(y_pred * (1 - y_true), sample_weight)
);
false_negatives.assign_add(
_weighted_sum((1 - y_pred) * y_true, sample_weight)
);
weights_intermediate.assign_add(_weighted_sum(y_true, sample_weight));

return weights_intermediate.AsTensor();
}

Tensor _weighted_sum(Tensor val, Tensor? sample_weight = null)
{
if (sample_weight != null)
{
val = tf.math.multiply(val, tf.expand_dims(sample_weight, 1));
}
return tf.reduce_sum(val, axis: _axis);
}

public override Tensor result()
{
var precision = tf.math.divide_no_nan(
true_positives.AsTensor(), true_positives.AsTensor() + false_positives.AsTensor()
);
var recall = tf.math.divide_no_nan(
true_positives.AsTensor(), true_positives.AsTensor() + false_negatives.AsTensor()
);

var mul_value = precision * recall;
var add_value = (tf.math.square(_beta) * precision) + recall;
var mean = tf.math.divide_no_nan(mul_value, add_value);
var f1_score = mean * (1 + tf.math.square(_beta));

Tensor weights;
if (_average == "weighted")
{
weights = tf.math.divide_no_nan(
weights_intermediate.AsTensor(), tf.reduce_sum(weights_intermediate.AsTensor())
);
f1_score = tf.reduce_sum(f1_score * weights);
}
// micro, macro
else if (_average != null)
{
f1_score = tf.reduce_mean(f1_score);
}

return f1_score;
}

public override void reset_states()
{
var reset_value = np.zeros(_init_shape, dtype: _dtype);
keras.backend.batch_set_value(
new List<(IVariableV1, NDArray)>
{
(true_positives, reset_value),
(false_positives, reset_value),
(false_negatives, reset_value),
(weights_intermediate, reset_value)
});
}
}

+ 6
- 0
src/TensorFlowNET.Keras/Metrics/MetricsApi.cs View File

@@ -86,6 +86,12 @@
public IMetricFunc CosineSimilarity(string name = "cosine_similarity", TF_DataType dtype = TF_DataType.TF_FLOAT, Axis? axis = null) public IMetricFunc CosineSimilarity(string name = "cosine_similarity", TF_DataType dtype = TF_DataType.TF_FLOAT, Axis? axis = null)
=> new CosineSimilarity(name: name, dtype: dtype, axis: axis ?? -1); => new CosineSimilarity(name: name, dtype: dtype, axis: axis ?? -1);


public IMetricFunc F1Score(int num_classes, string? average = null, float threshold = -1, string name = "fbeta_score", TF_DataType dtype = TF_DataType.TF_FLOAT)
=> new F1Score(num_classes, average: average, threshold: threshold, name: name, dtype: dtype);

public IMetricFunc FBetaScore(int num_classes, string? average = null, float beta = 0.1F, float threshold = -1, string name = "fbeta_score", TF_DataType dtype = TF_DataType.TF_FLOAT)
=> new FBetaScore(num_classes, average: average,beta: beta, threshold: threshold, name: name, dtype: dtype);

public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
=> new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype); => new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype);




+ 3
- 3
src/TensorFlowNET.Keras/Tensorflow.Keras.csproj View File

@@ -7,7 +7,7 @@
<Nullable>enable</Nullable> <Nullable>enable</Nullable>
<RootNamespace>Tensorflow.Keras</RootNamespace> <RootNamespace>Tensorflow.Keras</RootNamespace>
<Platforms>AnyCPU;x64</Platforms> <Platforms>AnyCPU;x64</Platforms>
<Version>0.10.3</Version>
<Version>0.10.4</Version>
<Authors>Haiping Chen</Authors> <Authors>Haiping Chen</Authors>
<Product>Keras for .NET</Product> <Product>Keras for .NET</Product>
<Copyright>Apache 2.0, Haiping Chen 2023</Copyright> <Copyright>Apache 2.0, Haiping Chen 2023</Copyright>
@@ -37,8 +37,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<RepositoryType>Git</RepositoryType> <RepositoryType>Git</RepositoryType>
<SignAssembly>true</SignAssembly> <SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
<AssemblyVersion>0.10.3.0</AssemblyVersion>
<FileVersion>0.10.3.0</FileVersion>
<AssemblyVersion>0.10.4.0</AssemblyVersion>
<FileVersion>0.10.4.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile> <PackageLicenseFile>LICENSE</PackageLicenseFile>
<Configurations>Debug;Release;GPU</Configurations> <Configurations>Debug;Release;GPU</Configurations>
</PropertyGroup> </PropertyGroup>


+ 28
- 0
test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs View File

@@ -114,6 +114,34 @@ public class MetricsTest : EagerModeTestBase
Assert.AreEqual(r, 0.6999999f); Assert.AreEqual(r, 0.6999999f);
} }


/// <summary>
/// https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/F1Score
/// </summary>
[TestMethod]
public void F1Score()
{
var y_true = np.array(new[,] { { 1, 1, 1 }, { 1, 0, 0 }, { 1, 1, 0 } });
var y_pred = np.array(new[,] { { 0.2f, 0.6f, 0.7f }, { 0.2f, 0.6f, 0.6f }, { 0.6f, 0.8f, 0f } });
var m = tf.keras.metrics.F1Score(num_classes: 3, threshold: 0.5f);
m.update_state(y_true, y_pred);
var r = m.result().numpy();
Assert.AreEqual(r, new[] { 0.5f, 0.8f, 0.6666667f });
}

/// <summary>
/// https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/FBetaScore
/// </summary>
[TestMethod]
public void FBetaScore()
{
var y_true = np.array(new[,] { { 1, 1, 1 }, { 1, 0, 0 }, { 1, 1, 0 } });
var y_pred = np.array(new[,] { { 0.2f, 0.6f, 0.7f }, { 0.2f, 0.6f, 0.6f }, { 0.6f, 0.8f, 0f } });
var m = tf.keras.metrics.FBetaScore(num_classes: 3, beta: 2.0f, threshold: 0.5f);
m.update_state(y_true, y_pred);
var r = m.result().numpy();
Assert.AreEqual(r, new[] { 0.3846154f, 0.90909094f, 0.8333334f });
}

/// <summary> /// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy
/// </summary> /// </summary>


Loading…
Cancel
Save