Browse Source

Add metrics of BinaryAccuracy, CategoricalAccuracy, CategoricalCrossentropy.

tags/v0.100.4-load-saved-model
Haiping Chen 2 years ago
parent
commit
9e877d1c15
12 changed files with 244 additions and 20 deletions
  1. +1
    -0
      src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs
  2. +42
    -5
      src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
  3. +16
    -8
      src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
  4. +22
    -0
      src/TensorFlowNET.Keras/Engine/Model.Compile.cs
  5. +2
    -2
      src/TensorFlowNET.Keras/Engine/Model.Metrics.cs
  6. +11
    -0
      src/TensorFlowNET.Keras/Metrics/BinaryAccuracy.cs
  7. +12
    -0
      src/TensorFlowNET.Keras/Metrics/CategoricalAccuracy.cs
  8. +16
    -0
      src/TensorFlowNET.Keras/Metrics/CategoricalCrossentropy.cs
  9. +21
    -0
      src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
  10. +39
    -1
      src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
  11. +2
    -4
      src/TensorFlowNET.Keras/Utils/losses_utils.cs
  12. +60
    -0
      test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs

+ 1
- 0
src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs View File

@@ -2,6 +2,7 @@

public interface IMetricFunc
{
string Name { get; }
/// <summary>
/// Accumulates metric statistics.
/// </summary>


+ 42
- 5
src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs View File

@@ -5,6 +5,10 @@ public interface IMetricsApi
Tensor binary_accuracy(Tensor y_true, Tensor y_pred);

Tensor categorical_accuracy(Tensor y_true, Tensor y_pred);
Tensor categorical_crossentropy(Tensor y_true, Tensor y_pred,
bool from_logits = false,
float label_smoothing = 0f,
Axis? axis = null);

Tensor mean_absolute_error(Tensor y_true, Tensor y_pred);

@@ -27,14 +31,39 @@ public interface IMetricsApi
/// <returns></returns>
Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5);

/// <summary>
/// Calculates how often predictions match binary labels.
/// </summary>
/// <returns></returns>
IMetricFunc BinaryAccuracy(string name = "binary_accuracy",
TF_DataType dtype = TF_DataType.TF_FLOAT,
float threshold = 05f);

/// <summary>
/// Calculates how often predictions match one-hot labels.
/// </summary>
/// <returns></returns>
IMetricFunc CategoricalCrossentropy(string name = "categorical_crossentropy",
TF_DataType dtype = TF_DataType.TF_FLOAT,
bool from_logits = false,
float label_smoothing = 0f,
Axis? axis = null);

/// <summary>
/// Computes the crossentropy metric between the labels and predictions.
/// </summary>
/// <returns></returns>
IMetricFunc CategoricalAccuracy(string name = "categorical_accuracy",
TF_DataType dtype = TF_DataType.TF_FLOAT);

/// <summary>
/// Computes how often targets are in the top K predictions.
/// </summary>
/// <param name="y_true"></param>
/// <param name="y_pred"></param>
/// <param name="k"></param>
/// <returns></returns>
IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT);
IMetricFunc TopKCategoricalAccuracy(int k = 5,
string name = "top_k_categorical_accuracy",
TF_DataType dtype = TF_DataType.TF_FLOAT);

/// <summary>
/// Computes the precision of the predictions with respect to the labels.
@@ -45,7 +74,11 @@ public interface IMetricsApi
/// <param name="name"></param>
/// <param name="dtype"></param>
/// <returns></returns>
IMetricFunc Precision(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT);
IMetricFunc Precision(float thresholds = 0.5f,
int top_k = 0,
int class_id = 0,
string name = "recall",
TF_DataType dtype = TF_DataType.TF_FLOAT);

/// <summary>
/// Computes the recall of the predictions with respect to the labels.
@@ -56,5 +89,9 @@ public interface IMetricsApi
/// <param name="name"></param>
/// <param name="dtype"></param>
/// <returns></returns>
IMetricFunc Recall(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT);
IMetricFunc Recall(float thresholds = 0.5f,
int top_k = 0,
int class_id = 0,
string name = "recall",
TF_DataType dtype = TF_DataType.TF_FLOAT);
}

+ 16
- 8
src/TensorFlowNET.Keras/Engine/MetricsContainer.cs View File

@@ -9,15 +9,21 @@ namespace Tensorflow.Keras.Engine
{
public class MetricsContainer : Container
{
string[] _user_metrics;
string[] _metric_names;
Metric[] _metrics;
List<Metric> _metrics_in_order;
IMetricFunc[] _user_metrics = new IMetricFunc[0];
string[] _metric_names = new string[0];
Metric[] _metrics = new Metric[0];
List<IMetricFunc> _metrics_in_order = new List<IMetricFunc>();

public MetricsContainer(string[] metrics, string[] output_names = null)
public MetricsContainer(IMetricFunc[] metrics, string[] output_names = null)
: base(output_names)
{
_user_metrics = metrics;
_built = false;
}

public MetricsContainer(string[] metrics, string[] output_names = null)
: base(output_names)
{
_metric_names = metrics;
_built = false;
}
@@ -46,9 +52,11 @@ namespace Tensorflow.Keras.Engine

void _create_ordered_metrics()
{
_metrics_in_order = new List<Metric>();
foreach (var m in _metrics)
_metrics_in_order.append(m);

foreach(var m in _user_metrics)
_metrics_in_order.append(m);
}

Metric[] _get_metric_objects(string[] metrics, Tensor y_t, Tensor y_p)
@@ -56,7 +64,7 @@ namespace Tensorflow.Keras.Engine
return metrics.Select(x => _get_metric_object(x, y_t, y_p)).ToArray();
}

Metric _get_metric_object(string metric, Tensor y_t, Tensor y_p)
public Metric _get_metric_object(string metric, Tensor y_t, Tensor y_p)
{
Func<Tensor, Tensor, Tensor> metric_obj = null;
if (metric == "accuracy" || metric == "acc")
@@ -94,7 +102,7 @@ namespace Tensorflow.Keras.Engine
return new MeanMetricWrapper(metric_obj, metric);
}

public IEnumerable<Metric> metrics
public IEnumerable<IMetricFunc> metrics
{
get
{


+ 22
- 0
src/TensorFlowNET.Keras/Engine/Model.Compile.cs View File

@@ -1,6 +1,7 @@
using System;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Optimizers;

namespace Tensorflow.Keras.Engine
@@ -31,6 +32,27 @@ namespace Tensorflow.Keras.Engine
_is_compiled = true;
}

public void compile(OptimizerV2 optimizer = null,
ILossFunc loss = null,
IMetricFunc[] metrics = null)
{
this.optimizer = optimizer ?? new RMSprop(new RMSpropArgs
{
});

this.loss = loss ?? new MeanSquaredError();

compiled_loss = new LossesContainer(loss, output_names: output_names);
compiled_metrics = new MetricsContainer(metrics, output_names: output_names);

int experimental_steps_per_execution = 1;
_configure_steps_per_execution(experimental_steps_per_execution);

// Initialize cache attrs.
_reset_compile_cache();
_is_compiled = true;
}

public void compile(string optimizer, string loss, string[] metrics)
{
var _optimizer = optimizer switch


+ 2
- 2
src/TensorFlowNET.Keras/Engine/Model.Metrics.cs View File

@@ -5,11 +5,11 @@ namespace Tensorflow.Keras.Engine
{
public partial class Model
{
public IEnumerable<Metric> metrics
public IEnumerable<IMetricFunc> metrics
{
get
{
var _metrics = new List<Metric>();
var _metrics = new List<IMetricFunc>();

if (_is_compiled)
{


+ 11
- 0
src/TensorFlowNET.Keras/Metrics/BinaryAccuracy.cs View File

@@ -0,0 +1,11 @@
namespace Tensorflow.Keras.Metrics;

public class BinaryAccuracy : MeanMetricWrapper
{
public BinaryAccuracy(string name = "binary_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT, float threshold = 0.5f)
: base((yt, yp) => metrics_utils.binary_matches(yt, yp),
name: name,
dtype: dtype)
{
}
}

+ 12
- 0
src/TensorFlowNET.Keras/Metrics/CategoricalAccuracy.cs View File

@@ -0,0 +1,12 @@
namespace Tensorflow.Keras.Metrics;

public class CategoricalAccuracy : MeanMetricWrapper
{
public CategoricalAccuracy(string name = "categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
: base((yt, yp) => metrics_utils.sparse_categorical_matches(
tf.math.argmax(yt, axis: -1), yp),
name: name,
dtype: dtype)
{
}
}

+ 16
- 0
src/TensorFlowNET.Keras/Metrics/CategoricalCrossentropy.cs View File

@@ -0,0 +1,16 @@
namespace Tensorflow.Keras.Metrics;

public class CategoricalCrossentropy : MeanMetricWrapper
{
public CategoricalCrossentropy(string name = "categorical_crossentropy",
TF_DataType dtype = TF_DataType.TF_FLOAT,
bool from_logits = false,
float label_smoothing = 0f,
Axis? axis = null)
: base((yt, yp) => keras.metrics.categorical_crossentropy(
yt, yp, from_logits: from_logits, label_smoothing: label_smoothing, axis: axis ?? -1),
name: name,
dtype: dtype)
{
}
}

+ 21
- 0
src/TensorFlowNET.Keras/Metrics/MetricsApi.cs View File

@@ -15,6 +15,18 @@
return math_ops.cast(eql, TF_DataType.TF_FLOAT);
}

public Tensor categorical_crossentropy(Tensor y_true, Tensor y_pred, bool from_logits = false, float label_smoothing = 0, Axis? axis = null)
{
y_true = tf.cast(y_true, y_pred.dtype);
// var label_smoothing_tensor = tf.convert_to_tensor(label_smoothing, dtype: y_pred.dtype);
if (label_smoothing > 0)
{
var num_classes = tf.cast(tf.shape(y_true)[-1], y_pred.dtype);
y_true = y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes);
}
return keras.backend.categorical_crossentropy(y_true, y_pred, from_logits: from_logits, axis: axis);
}

/// <summary>
/// Calculates how often predictions matches integer labels.
/// </summary>
@@ -59,6 +71,15 @@
);
}

public IMetricFunc BinaryAccuracy(string name = "binary_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT, float threshold = 5)
=> new BinaryAccuracy();

public IMetricFunc CategoricalAccuracy(string name = "categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
=> new CategoricalAccuracy(name: name, dtype: dtype);

public IMetricFunc CategoricalCrossentropy(string name = "categorical_crossentropy", TF_DataType dtype = TF_DataType.TF_FLOAT, bool from_logits = false, float label_smoothing = 0, Axis? axis = null)
=> new CategoricalCrossentropy();

public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
=> new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype);



+ 39
- 1
src/TensorFlowNET.Keras/Metrics/metrics_utils.cs View File

@@ -1,10 +1,48 @@
using Tensorflow.Keras.Utils;
using Tensorflow.NumPy;

namespace Tensorflow.Keras.Metrics;

public class metrics_utils
{
public static Tensor binary_matches(Tensor y_true, Tensor y_pred, float threshold = 0.5f)
{
y_pred = tf.cast(y_pred > threshold, y_pred.dtype);
return tf.cast(tf.equal(y_true, y_pred), keras.backend.floatx());
}

/// <summary>
/// Creates float Tensor, 1.0 for label-prediction match, 0.0 for mismatch.
/// </summary>
/// <param name="y_true"></param>
/// <param name="y_pred"></param>
/// <returns></returns>
public static Tensor sparse_categorical_matches(Tensor y_true, Tensor y_pred)
{
var reshape_matches = false;
var y_true_rank = y_true.shape.ndim;
var y_pred_rank = y_pred.shape.ndim;
var y_true_org_shape = tf.shape(y_true);

if (y_true_rank > -1 && y_pred_rank > -1 && y_true.ndim == y_pred.ndim )
{
reshape_matches = true;
y_true = tf.squeeze(y_true, new Shape(-1));
}
y_pred = tf.math.argmax(y_pred, axis: -1);

var matches = tf.cast(
tf.equal(y_true, y_pred),
dtype: keras.backend.floatx()
);

if (reshape_matches)
{
return tf.reshape(matches, shape: y_true_org_shape);
}

return matches;
}

public static Tensor sparse_top_k_categorical_matches(Tensor y_true, Tensor y_pred, int k = 5)
{
var reshape_matches = false;


+ 2
- 4
src/TensorFlowNET.Keras/Utils/losses_utils.cs View File

@@ -75,10 +75,8 @@ namespace Tensorflow.Keras.Utils
{
sample_weight = tf.expand_dims(sample_weight, -1);
}
else
{
return (y_pred, y_true, sample_weight);
}

return (y_pred, y_true, sample_weight);
}

throw new NotImplementedException("");


+ 60
- 0
test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs View File

@@ -14,6 +14,66 @@ namespace TensorFlowNET.Keras.UnitTest;
[TestClass]
public class MetricsTest : EagerModeTestBase
{
/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/BinaryAccuracy
/// </summary>
[TestMethod]
public void BinaryAccuracy()
{
var y_true = np.array(new[,] { { 1 }, { 1 },{ 0 }, { 0 } });
var y_pred = np.array(new[,] { { 0.98f }, { 1f }, { 0f }, { 0.6f } });
var m = tf.keras.metrics.BinaryAccuracy();
/*m.update_state(y_true, y_pred);
var r = m.result().numpy();
Assert.AreEqual(r, 0.75f);

m.reset_states();*/
var weights = np.array(new[] { 1f, 0f, 0f, 1f });
m.update_state(y_true, y_pred, sample_weight: weights);
var r = m.result().numpy();
Assert.AreEqual(r, 0.5f);
}

/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CategoricalAccuracy
/// </summary>
[TestMethod]
public void CategoricalAccuracy()
{
var y_true = np.array(new[,] { { 0, 0, 1 }, { 0, 1, 0 } });
var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } });
var m = tf.keras.metrics.CategoricalAccuracy();
m.update_state(y_true, y_pred);
var r = m.result().numpy();
Assert.AreEqual(r, 0.5f);

m.reset_states();
var weights = np.array(new[] { 0.7f, 0.3f });
m.update_state(y_true, y_pred, sample_weight: weights);
r = m.result().numpy();
Assert.AreEqual(r, 0.3f);
}

/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CategoricalCrossentropy
/// </summary>
[TestMethod]
public void CategoricalCrossentropy()
{
var y_true = np.array(new[,] { { 0, 1, 0 }, { 0, 0, 1 } });
var y_pred = np.array(new[,] { { 0.05f, 0.95f, 0f }, { 0.1f, 0.8f, 0.1f } });
var m = tf.keras.metrics.CategoricalCrossentropy();
m.update_state(y_true, y_pred);
var r = m.result().numpy();
Assert.AreEqual(r, 1.1769392f);

m.reset_states();
var weights = np.array(new[] { 0.3f, 0.7f });
m.update_state(y_true, y_pred, sample_weight: weights);
r = m.result().numpy();
Assert.AreEqual(r, 1.6271976f);
}

/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy
/// </summary>


Loading…
Cancel
Save