Browse Source

Abstract IMetricFunc.

tags/v0.100.4-load-saved-model
Haiping Chen 2 years ago
parent
commit
a5289b9bb3
8 changed files with 110 additions and 5 deletions
  1. +17
    -0
      src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs
  2. +9
    -0
      src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
  3. +3
    -0
      src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs
  4. +1
    -1
      src/TensorFlowNET.Keras/Metrics/Metric.cs
  5. +4
    -3
      src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
  6. +12
    -0
      src/TensorFlowNET.Keras/Metrics/TopKCategoricalAccuracy.cs
  7. +44
    -1
      src/TensorFlowNET.Keras/Utils/losses_utils.cs
  8. +20
    -0
      test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs

+ 17
- 0
src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs View File

@@ -0,0 +1,17 @@
namespace Tensorflow.Keras.Metrics;

public interface IMetricFunc
{
/// <summary>
/// Accumulates metric statistics.
/// </summary>
/// <param name="y_true"></param>
/// <param name="y_pred"></param>
/// <param name="sample_weight"></param>
/// <returns></returns>
Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null);

Tensor result();

void reset_states();
}

+ 9
- 0
src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs View File

@@ -26,4 +26,13 @@ public interface IMetricsApi
/// <param name="k"></param>
/// <returns></returns>
Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5);

/// <summary>
/// Computes how often targets are in the top K predictions.
/// </summary>
/// <param name="y_true"></param>
/// <param name="y_pred"></param>
/// <param name="k"></param>
/// <returns></returns>
IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT);
}

+ 3
- 0
src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs View File

@@ -1,4 +1,5 @@
using System;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Metrics
{
@@ -17,6 +18,8 @@ namespace Tensorflow.Keras.Metrics
y_true = math_ops.cast(y_true, _dtype);
y_pred = math_ops.cast(y_pred, _dtype);

(y_pred, y_true) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true: y_true);

var matches = _fn(y_true, y_pred);
return update_state(matches, sample_weight: sample_weight);
}


+ 1
- 1
src/TensorFlowNET.Keras/Metrics/Metric.cs View File

@@ -9,7 +9,7 @@ namespace Tensorflow.Keras.Metrics
/// <summary>
/// Encapsulates metric logic and state.
/// </summary>
public class Metric : Layer
public class Metric : Layer, IMetricFunc
{
protected IVariableV1 total;
protected IVariableV1 count;


+ 4
- 3
src/TensorFlowNET.Keras/Metrics/MetricsApi.cs View File

@@ -1,6 +1,4 @@
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Metrics
namespace Tensorflow.Keras.Metrics
{
public class MetricsApi : IMetricsApi
{
@@ -60,5 +58,8 @@ namespace Tensorflow.Keras.Metrics
tf.math.argmax(y_true, axis: -1), y_pred, k
);
}

public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
=> new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype);
}
}

+ 12
- 0
src/TensorFlowNET.Keras/Metrics/TopKCategoricalAccuracy.cs View File

@@ -0,0 +1,12 @@
namespace Tensorflow.Keras.Metrics;

public class TopKCategoricalAccuracy : MeanMetricWrapper
{
public TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
: base((yt, yp) => metrics_utils.sparse_top_k_categorical_matches(
tf.math.argmax(yt, axis: -1), yp, k),
name: name,
dtype: dtype)
{
}
}

+ 44
- 1
src/TensorFlowNET.Keras/Utils/losses_utils.cs View File

@@ -15,6 +15,7 @@
******************************************************************************/

using System;
using System.Xml.Linq;
using Tensorflow.Keras.Losses;
using static Tensorflow.Binding;

@@ -37,15 +38,57 @@ namespace Tensorflow.Keras.Utils
});
}

public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor sample_weight)
public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor y_true = null, Tensor sample_weight = null)
{
var y_pred_shape = y_pred.shape;
var y_pred_rank = y_pred_shape.ndim;
if (y_true != null)
{
var y_true_shape = y_true.shape;
var y_true_rank = y_true_shape.ndim;
if (y_true_rank > -1 && y_pred_rank > -1)
{
if (y_pred_rank - y_true_rank != 1 || y_pred_shape[-1] == 1)
{
(y_true, y_pred) = remove_squeezable_dimensions(y_true, y_pred);
}
}
}

if (sample_weight == null)
{
return (y_pred, y_true);
}

var weights_shape = sample_weight.shape;
var weights_rank = weights_shape.ndim;
if (weights_rank == 0)
return (y_pred, sample_weight);

if (y_pred_rank > -1 && weights_rank > -1)
{
if (weights_rank - y_pred_rank == 1)
{
sample_weight = tf.squeeze(sample_weight, -1);
}
else if (y_pred_rank - weights_rank == 1)
{
sample_weight = tf.expand_dims(sample_weight, -1);
}
else
{
return (y_pred, sample_weight);
}
}

throw new NotImplementedException("");
}

public static (Tensor, Tensor) remove_squeezable_dimensions(Tensor labels, Tensor predictions, int expected_rank_diff = 0, string name = null)
{
return (labels, predictions);
}

public static Tensor reduce_weighted_loss(Tensor weighted_losses, string reduction)
{
if (reduction == ReductionV2.NONE)


+ 20
- 0
test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs View File

@@ -14,6 +14,26 @@ namespace TensorFlowNET.Keras.UnitTest;
[TestClass]
public class MetricsTest : EagerModeTestBase
{
/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy
/// </summary>
[TestMethod]
public void TopKCategoricalAccuracy()
{
var y_true = np.array(new[,] { { 0, 0, 1 }, { 0, 1, 0 } });
var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } });
var m = tf.keras.metrics.TopKCategoricalAccuracy(k: 1);
m.update_state(y_true, y_pred);
var r = m.result().numpy();
Assert.AreEqual(r, 0.5f);

m.reset_states();
var weights = np.array(new[] { 0.7f, 0.3f });
m.update_state(y_true, y_pred, sample_weight: weights);
r = m.result().numpy();
Assert.AreEqual(r, 0.3f);
}

/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/top_k_categorical_accuracy
/// </summary>


Loading…
Cancel
Save