@@ -36,6 +36,15 @@ namespace Tensorflow | |||||
public Tensor erf(Tensor x, string name = null) | public Tensor erf(Tensor x, string name = null) | ||||
=> math_ops.erf(x, name); | => math_ops.erf(x, name); | ||||
public Tensor multiply(Tensor x, Tensor y, string name = null) | |||||
=> math_ops.multiply(x, y, name: name); | |||||
public Tensor divide_no_nan(Tensor a, Tensor b, string name = null) | |||||
=> math_ops.div_no_nan(a, b); | |||||
public Tensor square(Tensor x, string name = null) | |||||
=> math_ops.square(x, name: name); | |||||
public Tensor sum(Tensor x, Axis? axis = null, string name = null) | public Tensor sum(Tensor x, Axis? axis = null, string name = null) | ||||
=> math_ops.reduce_sum(x, axis: axis, name: name); | => math_ops.reduce_sum(x, axis: axis, name: name); | ||||
@@ -71,6 +71,27 @@ public interface IMetricsApi | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT, | TF_DataType dtype = TF_DataType.TF_FLOAT, | ||||
Axis? axis = null); | Axis? axis = null); | ||||
/// <summary> | |||||
/// Computes F-1 Score. | |||||
/// </summary> | |||||
/// <returns></returns> | |||||
IMetricFunc F1Score(int num_classes, | |||||
string? average = null, | |||||
float threshold = -1f, | |||||
string name = "fbeta_score", | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT); | |||||
/// <summary> | |||||
/// Computes F-Beta score. | |||||
/// </summary> | |||||
/// <returns></returns> | |||||
IMetricFunc FBetaScore(int num_classes, | |||||
string? average = null, | |||||
float beta = 0.1f, | |||||
float threshold = -1f, | |||||
string name = "fbeta_score", | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT); | |||||
/// <summary> | /// <summary> | ||||
/// Computes how often targets are in the top K predictions. | /// Computes how often targets are in the top K predictions. | ||||
/// </summary> | /// </summary> | ||||
@@ -5,7 +5,7 @@ | |||||
<AssemblyName>Tensorflow.Binding</AssemblyName> | <AssemblyName>Tensorflow.Binding</AssemblyName> | ||||
<RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
<TargetTensorFlow>2.10.0</TargetTensorFlow> | <TargetTensorFlow>2.10.0</TargetTensorFlow> | ||||
<Version>0.100.3</Version> | |||||
<Version>0.100.4</Version> | |||||
<LangVersion>10.0</LangVersion> | <LangVersion>10.0</LangVersion> | ||||
<Nullable>enable</Nullable> | <Nullable>enable</Nullable> | ||||
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | ||||
@@ -20,7 +20,7 @@ | |||||
<Description>Google's TensorFlow full binding in .NET Standard. | <Description>Google's TensorFlow full binding in .NET Standard. | ||||
Building, training and infering deep learning models. | Building, training and infering deep learning models. | ||||
https://tensorflownet.readthedocs.io</Description> | https://tensorflownet.readthedocs.io</Description> | ||||
<AssemblyVersion>0.100.3.0</AssemblyVersion> | |||||
<AssemblyVersion>0.100.4.0</AssemblyVersion> | |||||
<PackageReleaseNotes> | <PackageReleaseNotes> | ||||
tf.net 0.100.x and above are based on tensorflow native 2.10.0 | tf.net 0.100.x and above are based on tensorflow native 2.10.0 | ||||
@@ -38,7 +38,7 @@ https://tensorflownet.readthedocs.io</Description> | |||||
tf.net 0.7x.x aligns with TensorFlow v2.7.x native library. | tf.net 0.7x.x aligns with TensorFlow v2.7.x native library. | ||||
tf.net 0.10x.x aligns with TensorFlow v2.10.x native library. | tf.net 0.10x.x aligns with TensorFlow v2.10.x native library. | ||||
</PackageReleaseNotes> | </PackageReleaseNotes> | ||||
<FileVersion>0.100.3.0</FileVersion> | |||||
<FileVersion>0.100.4.0</FileVersion> | |||||
<PackageLicenseFile>LICENSE</PackageLicenseFile> | <PackageLicenseFile>LICENSE</PackageLicenseFile> | ||||
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | ||||
<SignAssembly>true</SignAssembly> | <SignAssembly>true</SignAssembly> | ||||
@@ -0,0 +1,13 @@ | |||||
namespace Tensorflow.Keras.Metrics; | |||||
public class F1Score : FBetaScore | |||||
{ | |||||
public F1Score(int num_classes, | |||||
string? average = null, | |||||
float? threshold = -1f, | |||||
string name = "f1_score", | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
: base(num_classes, average: average, threshold: threshold, beta: 1f, name: name, dtype: dtype) | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,131 @@ | |||||
namespace Tensorflow.Keras.Metrics; | |||||
public class FBetaScore : Metric | |||||
{ | |||||
int _num_classes; | |||||
string? _average; | |||||
Tensor _beta; | |||||
Tensor _threshold; | |||||
Axis _axis; | |||||
int[] _init_shape; | |||||
IVariableV1 true_positives; | |||||
IVariableV1 false_positives; | |||||
IVariableV1 false_negatives; | |||||
IVariableV1 weights_intermediate; | |||||
public FBetaScore(int num_classes, | |||||
string? average = null, | |||||
float beta = 0.1f, | |||||
float? threshold = -1f, | |||||
string name = "fbeta_score", | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
: base(name: name, dtype: dtype) | |||||
{ | |||||
_num_classes = num_classes; | |||||
_average = average; | |||||
_beta = constant_op.constant(beta); | |||||
_dtype = dtype; | |||||
if (threshold.HasValue) | |||||
{ | |||||
_threshold = constant_op.constant(threshold); | |||||
} | |||||
_init_shape = new int[0]; | |||||
if (average != "micro") | |||||
{ | |||||
_axis = 0; | |||||
_init_shape = new int[] { num_classes }; | |||||
} | |||||
true_positives = add_weight("true_positives", shape: _init_shape, initializer: tf.initializers.zeros_initializer()); | |||||
false_positives = add_weight("false_positives", shape: _init_shape, initializer: tf.initializers.zeros_initializer()); | |||||
false_negatives = add_weight("false_negatives", shape: _init_shape, initializer: tf.initializers.zeros_initializer()); | |||||
weights_intermediate = add_weight("weights_intermediate", shape: _init_shape, initializer: tf.initializers.zeros_initializer()); | |||||
} | |||||
public override Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) | |||||
{ | |||||
if (_threshold == null) | |||||
{ | |||||
_threshold = tf.reduce_max(y_pred, axis: -1, keepdims: true); | |||||
// make sure [0, 0, 0] doesn't become [1, 1, 1] | |||||
// Use abs(x) > eps, instead of x != 0 to check for zero | |||||
y_pred = tf.logical_and(y_pred >= _threshold, tf.abs(y_pred) > 1e-12); | |||||
} | |||||
else | |||||
{ | |||||
y_pred = y_pred > _threshold; | |||||
} | |||||
y_true = tf.cast(y_true, _dtype); | |||||
y_pred = tf.cast(y_pred, _dtype); | |||||
true_positives.assign_add(_weighted_sum(y_pred * y_true, sample_weight)); | |||||
false_positives.assign_add( | |||||
_weighted_sum(y_pred * (1 - y_true), sample_weight) | |||||
); | |||||
false_negatives.assign_add( | |||||
_weighted_sum((1 - y_pred) * y_true, sample_weight) | |||||
); | |||||
weights_intermediate.assign_add(_weighted_sum(y_true, sample_weight)); | |||||
return weights_intermediate.AsTensor(); | |||||
} | |||||
Tensor _weighted_sum(Tensor val, Tensor? sample_weight = null) | |||||
{ | |||||
if (sample_weight != null) | |||||
{ | |||||
val = tf.math.multiply(val, tf.expand_dims(sample_weight, 1)); | |||||
} | |||||
return tf.reduce_sum(val, axis: _axis); | |||||
} | |||||
public override Tensor result() | |||||
{ | |||||
var precision = tf.math.divide_no_nan( | |||||
true_positives.AsTensor(), true_positives.AsTensor() + false_positives.AsTensor() | |||||
); | |||||
var recall = tf.math.divide_no_nan( | |||||
true_positives.AsTensor(), true_positives.AsTensor() + false_negatives.AsTensor() | |||||
); | |||||
var mul_value = precision * recall; | |||||
var add_value = (tf.math.square(_beta) * precision) + recall; | |||||
var mean = tf.math.divide_no_nan(mul_value, add_value); | |||||
var f1_score = mean * (1 + tf.math.square(_beta)); | |||||
Tensor weights; | |||||
if (_average == "weighted") | |||||
{ | |||||
weights = tf.math.divide_no_nan( | |||||
weights_intermediate.AsTensor(), tf.reduce_sum(weights_intermediate.AsTensor()) | |||||
); | |||||
f1_score = tf.reduce_sum(f1_score * weights); | |||||
} | |||||
// micro, macro | |||||
else if (_average != null) | |||||
{ | |||||
f1_score = tf.reduce_mean(f1_score); | |||||
} | |||||
return f1_score; | |||||
} | |||||
public override void reset_states() | |||||
{ | |||||
var reset_value = np.zeros(_init_shape, dtype: _dtype); | |||||
keras.backend.batch_set_value( | |||||
new List<(IVariableV1, NDArray)> | |||||
{ | |||||
(true_positives, reset_value), | |||||
(false_positives, reset_value), | |||||
(false_negatives, reset_value), | |||||
(weights_intermediate, reset_value) | |||||
}); | |||||
} | |||||
} |
@@ -86,6 +86,12 @@ | |||||
public IMetricFunc CosineSimilarity(string name = "cosine_similarity", TF_DataType dtype = TF_DataType.TF_FLOAT, Axis? axis = null) | public IMetricFunc CosineSimilarity(string name = "cosine_similarity", TF_DataType dtype = TF_DataType.TF_FLOAT, Axis? axis = null) | ||||
=> new CosineSimilarity(name: name, dtype: dtype, axis: axis ?? -1); | => new CosineSimilarity(name: name, dtype: dtype, axis: axis ?? -1); | ||||
public IMetricFunc F1Score(int num_classes, string? average = null, float threshold = -1, string name = "fbeta_score", TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
=> new F1Score(num_classes, average: average, threshold: threshold, name: name, dtype: dtype); | |||||
public IMetricFunc FBetaScore(int num_classes, string? average = null, float beta = 0.1F, float threshold = -1, string name = "fbeta_score", TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
=> new FBetaScore(num_classes, average: average,beta: beta, threshold: threshold, name: name, dtype: dtype); | |||||
public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) | public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) | ||||
=> new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype); | => new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype); | ||||
@@ -7,7 +7,7 @@ | |||||
<Nullable>enable</Nullable> | <Nullable>enable</Nullable> | ||||
<RootNamespace>Tensorflow.Keras</RootNamespace> | <RootNamespace>Tensorflow.Keras</RootNamespace> | ||||
<Platforms>AnyCPU;x64</Platforms> | <Platforms>AnyCPU;x64</Platforms> | ||||
<Version>0.10.3</Version> | |||||
<Version>0.10.4</Version> | |||||
<Authors>Haiping Chen</Authors> | <Authors>Haiping Chen</Authors> | ||||
<Product>Keras for .NET</Product> | <Product>Keras for .NET</Product> | ||||
<Copyright>Apache 2.0, Haiping Chen 2023</Copyright> | <Copyright>Apache 2.0, Haiping Chen 2023</Copyright> | ||||
@@ -37,8 +37,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||||
<RepositoryType>Git</RepositoryType> | <RepositoryType>Git</RepositoryType> | ||||
<SignAssembly>true</SignAssembly> | <SignAssembly>true</SignAssembly> | ||||
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | ||||
<AssemblyVersion>0.10.3.0</AssemblyVersion> | |||||
<FileVersion>0.10.3.0</FileVersion> | |||||
<AssemblyVersion>0.10.4.0</AssemblyVersion> | |||||
<FileVersion>0.10.4.0</FileVersion> | |||||
<PackageLicenseFile>LICENSE</PackageLicenseFile> | <PackageLicenseFile>LICENSE</PackageLicenseFile> | ||||
<Configurations>Debug;Release;GPU</Configurations> | <Configurations>Debug;Release;GPU</Configurations> | ||||
</PropertyGroup> | </PropertyGroup> | ||||
@@ -114,6 +114,34 @@ public class MetricsTest : EagerModeTestBase | |||||
Assert.AreEqual(r, 0.6999999f); | Assert.AreEqual(r, 0.6999999f); | ||||
} | } | ||||
/// <summary> | |||||
/// https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/F1Score | |||||
/// </summary> | |||||
[TestMethod] | |||||
public void F1Score() | |||||
{ | |||||
var y_true = np.array(new[,] { { 1, 1, 1 }, { 1, 0, 0 }, { 1, 1, 0 } }); | |||||
var y_pred = np.array(new[,] { { 0.2f, 0.6f, 0.7f }, { 0.2f, 0.6f, 0.6f }, { 0.6f, 0.8f, 0f } }); | |||||
var m = tf.keras.metrics.F1Score(num_classes: 3, threshold: 0.5f); | |||||
m.update_state(y_true, y_pred); | |||||
var r = m.result().numpy(); | |||||
Assert.AreEqual(r, new[] { 0.5f, 0.8f, 0.6666667f }); | |||||
} | |||||
/// <summary> | |||||
/// https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/FBetaScore | |||||
/// </summary> | |||||
[TestMethod] | |||||
public void FBetaScore() | |||||
{ | |||||
var y_true = np.array(new[,] { { 1, 1, 1 }, { 1, 0, 0 }, { 1, 1, 0 } }); | |||||
var y_pred = np.array(new[,] { { 0.2f, 0.6f, 0.7f }, { 0.2f, 0.6f, 0.6f }, { 0.6f, 0.8f, 0f } }); | |||||
var m = tf.keras.metrics.FBetaScore(num_classes: 3, beta: 2.0f, threshold: 0.5f); | |||||
m.update_state(y_true, y_pred); | |||||
var r = m.result().numpy(); | |||||
Assert.AreEqual(r, new[] { 0.3846154f, 0.90909094f, 0.8333334f }); | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy | /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy | ||||
/// </summary> | /// </summary> | ||||