Browse Source

Add metrics of F1Score and FBetaScore.

tags/v0.100.4-load-saved-model
Haiping Chen 2 years ago
parent
commit
067c1ff92a
8 changed files with 214 additions and 6 deletions
  1. +9
    -0
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +21
    -0
      src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
  3. +3
    -3
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  4. +13
    -0
      src/TensorFlowNET.Keras/Metrics/F1Score.cs
  5. +131
    -0
      src/TensorFlowNET.Keras/Metrics/FBetaScore.cs
  6. +6
    -0
      src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
  7. +3
    -3
      src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
  8. +28
    -0
      test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs

+ 9
- 0
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -36,6 +36,15 @@ namespace Tensorflow
public Tensor erf(Tensor x, string name = null)
=> math_ops.erf(x, name);

public Tensor multiply(Tensor x, Tensor y, string name = null)
=> math_ops.multiply(x, y, name: name);

public Tensor divide_no_nan(Tensor a, Tensor b, string name = null)
=> math_ops.div_no_nan(a, b);

public Tensor square(Tensor x, string name = null)
=> math_ops.square(x, name: name);

public Tensor sum(Tensor x, Axis? axis = null, string name = null)
=> math_ops.reduce_sum(x, axis: axis, name: name);



+ 21
- 0
src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs View File

@@ -71,6 +71,27 @@ public interface IMetricsApi
TF_DataType dtype = TF_DataType.TF_FLOAT,
Axis? axis = null);

/// <summary>
/// Computes F-1 Score.
/// </summary>
/// <returns></returns>
IMetricFunc F1Score(int num_classes,
string? average = null,
float threshold = -1f,
string name = "fbeta_score",
TF_DataType dtype = TF_DataType.TF_FLOAT);

/// <summary>
/// Computes F-Beta score.
/// </summary>
/// <returns></returns>
IMetricFunc FBetaScore(int num_classes,
string? average = null,
float beta = 0.1f,
float threshold = -1f,
string name = "fbeta_score",
TF_DataType dtype = TF_DataType.TF_FLOAT);
/// <summary>
/// Computes how often targets are in the top K predictions.
/// </summary>


+ 3
- 3
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>Tensorflow.Binding</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>2.10.0</TargetTensorFlow>
<Version>0.100.3</Version>
<Version>0.100.4</Version>
<LangVersion>10.0</LangVersion>
<Nullable>enable</Nullable>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
@@ -20,7 +20,7 @@
<Description>Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.100.3.0</AssemblyVersion>
<AssemblyVersion>0.100.4.0</AssemblyVersion>
<PackageReleaseNotes>
tf.net 0.100.x and above are based on tensorflow native 2.10.0

@@ -38,7 +38,7 @@ https://tensorflownet.readthedocs.io</Description>
tf.net 0.7x.x aligns with TensorFlow v2.7.x native library.
tf.net 0.10x.x aligns with TensorFlow v2.10.x native library.
</PackageReleaseNotes>
<FileVersion>0.100.3.0</FileVersion>
<FileVersion>0.100.4.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly>


+ 13
- 0
src/TensorFlowNET.Keras/Metrics/F1Score.cs View File

@@ -0,0 +1,13 @@
namespace Tensorflow.Keras.Metrics;

public class F1Score : FBetaScore
{
public F1Score(int num_classes,
string? average = null,
float? threshold = -1f,
string name = "f1_score",
TF_DataType dtype = TF_DataType.TF_FLOAT)
: base(num_classes, average: average, threshold: threshold, beta: 1f, name: name, dtype: dtype)
{
}
}

+ 131
- 0
src/TensorFlowNET.Keras/Metrics/FBetaScore.cs View File

@@ -0,0 +1,131 @@
namespace Tensorflow.Keras.Metrics;

public class FBetaScore : Metric
{
int _num_classes;
string? _average;
Tensor _beta;
Tensor _threshold;
Axis _axis;
int[] _init_shape;

IVariableV1 true_positives;
IVariableV1 false_positives;
IVariableV1 false_negatives;
IVariableV1 weights_intermediate;

public FBetaScore(int num_classes,
string? average = null,
float beta = 0.1f,
float? threshold = -1f,
string name = "fbeta_score",
TF_DataType dtype = TF_DataType.TF_FLOAT)
: base(name: name, dtype: dtype)
{
_num_classes = num_classes;
_average = average;
_beta = constant_op.constant(beta);
_dtype = dtype;

if (threshold.HasValue)
{
_threshold = constant_op.constant(threshold);
}
_init_shape = new int[0];

if (average != "micro")
{
_axis = 0;
_init_shape = new int[] { num_classes };
}

true_positives = add_weight("true_positives", shape: _init_shape, initializer: tf.initializers.zeros_initializer());
false_positives = add_weight("false_positives", shape: _init_shape, initializer: tf.initializers.zeros_initializer());
false_negatives = add_weight("false_negatives", shape: _init_shape, initializer: tf.initializers.zeros_initializer());
weights_intermediate = add_weight("weights_intermediate", shape: _init_shape, initializer: tf.initializers.zeros_initializer());
}

public override Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{
if (_threshold == null)
{
_threshold = tf.reduce_max(y_pred, axis: -1, keepdims: true);
// make sure [0, 0, 0] doesn't become [1, 1, 1]
// Use abs(x) > eps, instead of x != 0 to check for zero
y_pred = tf.logical_and(y_pred >= _threshold, tf.abs(y_pred) > 1e-12);
}
else
{
y_pred = y_pred > _threshold;
}

y_true = tf.cast(y_true, _dtype);
y_pred = tf.cast(y_pred, _dtype);

true_positives.assign_add(_weighted_sum(y_pred * y_true, sample_weight));
false_positives.assign_add(
_weighted_sum(y_pred * (1 - y_true), sample_weight)
);
false_negatives.assign_add(
_weighted_sum((1 - y_pred) * y_true, sample_weight)
);
weights_intermediate.assign_add(_weighted_sum(y_true, sample_weight));

return weights_intermediate.AsTensor();
}

Tensor _weighted_sum(Tensor val, Tensor? sample_weight = null)
{
if (sample_weight != null)
{
val = tf.math.multiply(val, tf.expand_dims(sample_weight, 1));
}
return tf.reduce_sum(val, axis: _axis);
}

public override Tensor result()
{
var precision = tf.math.divide_no_nan(
true_positives.AsTensor(), true_positives.AsTensor() + false_positives.AsTensor()
);
var recall = tf.math.divide_no_nan(
true_positives.AsTensor(), true_positives.AsTensor() + false_negatives.AsTensor()
);

var mul_value = precision * recall;
var add_value = (tf.math.square(_beta) * precision) + recall;
var mean = tf.math.divide_no_nan(mul_value, add_value);
var f1_score = mean * (1 + tf.math.square(_beta));

Tensor weights;
if (_average == "weighted")
{
weights = tf.math.divide_no_nan(
weights_intermediate.AsTensor(), tf.reduce_sum(weights_intermediate.AsTensor())
);
f1_score = tf.reduce_sum(f1_score * weights);
}
// micro, macro
else if (_average != null)
{
f1_score = tf.reduce_mean(f1_score);
}

return f1_score;
}

public override void reset_states()
{
var reset_value = np.zeros(_init_shape, dtype: _dtype);
keras.backend.batch_set_value(
new List<(IVariableV1, NDArray)>
{
(true_positives, reset_value),
(false_positives, reset_value),
(false_negatives, reset_value),
(weights_intermediate, reset_value)
});
}
}

+ 6
- 0
src/TensorFlowNET.Keras/Metrics/MetricsApi.cs View File

@@ -86,6 +86,12 @@
public IMetricFunc CosineSimilarity(string name = "cosine_similarity", TF_DataType dtype = TF_DataType.TF_FLOAT, Axis? axis = null)
=> new CosineSimilarity(name: name, dtype: dtype, axis: axis ?? -1);

public IMetricFunc F1Score(int num_classes, string? average = null, float threshold = -1, string name = "fbeta_score", TF_DataType dtype = TF_DataType.TF_FLOAT)
=> new F1Score(num_classes, average: average, threshold: threshold, name: name, dtype: dtype);

public IMetricFunc FBetaScore(int num_classes, string? average = null, float beta = 0.1F, float threshold = -1, string name = "fbeta_score", TF_DataType dtype = TF_DataType.TF_FLOAT)
=> new FBetaScore(num_classes, average: average,beta: beta, threshold: threshold, name: name, dtype: dtype);

public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
=> new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype);



+ 3
- 3
src/TensorFlowNET.Keras/Tensorflow.Keras.csproj View File

@@ -7,7 +7,7 @@
<Nullable>enable</Nullable>
<RootNamespace>Tensorflow.Keras</RootNamespace>
<Platforms>AnyCPU;x64</Platforms>
<Version>0.10.3</Version>
<Version>0.10.4</Version>
<Authors>Haiping Chen</Authors>
<Product>Keras for .NET</Product>
<Copyright>Apache 2.0, Haiping Chen 2023</Copyright>
@@ -37,8 +37,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<RepositoryType>Git</RepositoryType>
<SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
<AssemblyVersion>0.10.3.0</AssemblyVersion>
<FileVersion>0.10.3.0</FileVersion>
<AssemblyVersion>0.10.4.0</AssemblyVersion>
<FileVersion>0.10.4.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<Configurations>Debug;Release;GPU</Configurations>
</PropertyGroup>


+ 28
- 0
test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs View File

@@ -114,6 +114,34 @@ public class MetricsTest : EagerModeTestBase
Assert.AreEqual(r, 0.6999999f);
}

/// <summary>
/// https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/F1Score
/// </summary>
[TestMethod]
public void F1Score()
{
var y_true = np.array(new[,] { { 1, 1, 1 }, { 1, 0, 0 }, { 1, 1, 0 } });
var y_pred = np.array(new[,] { { 0.2f, 0.6f, 0.7f }, { 0.2f, 0.6f, 0.6f }, { 0.6f, 0.8f, 0f } });
var m = tf.keras.metrics.F1Score(num_classes: 3, threshold: 0.5f);
m.update_state(y_true, y_pred);
var r = m.result().numpy();
Assert.AreEqual(r, new[] { 0.5f, 0.8f, 0.6666667f });
}

/// <summary>
/// https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/FBetaScore
/// </summary>
[TestMethod]
public void FBetaScore()
{
var y_true = np.array(new[,] { { 1, 1, 1 }, { 1, 0, 0 }, { 1, 1, 0 } });
var y_pred = np.array(new[,] { { 0.2f, 0.6f, 0.7f }, { 0.2f, 0.6f, 0.6f }, { 0.6f, 0.8f, 0f } });
var m = tf.keras.metrics.FBetaScore(num_classes: 3, beta: 2.0f, threshold: 0.5f);
m.update_state(y_true, y_pred);
var r = m.result().numpy();
Assert.AreEqual(r, new[] { 0.3846154f, 0.90909094f, 0.8333334f });
}

/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy
/// </summary>


Loading…
Cancel
Save