Browse Source

Add mertic of Recall.

tags/v0.100.4-load-saved-model
Haiping Chen 2 years ago
parent
commit
217cfd2d41
10 changed files with 269 additions and 8 deletions
  1. +11
    -0
      src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs
  2. +3
    -0
      src/TensorFlowNET.Core/Operations/array_ops.cs
  3. +3
    -1
      src/TensorFlowNET.Keras/GlobalUsing.cs
  4. +1
    -1
      src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs
  5. +3
    -0
      src/TensorFlowNET.Keras/Metrics/MetricsApi.cs
  6. +53
    -0
      src/TensorFlowNET.Keras/Metrics/Recall.cs
  7. +1
    -1
      src/TensorFlowNET.Keras/Metrics/Reduce.cs
  8. +170
    -1
      src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
  9. +4
    -4
      src/TensorFlowNET.Keras/Utils/losses_utils.cs
  10. +20
    -0
      test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs

+ 11
- 0
src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs View File

@@ -35,4 +35,15 @@ public interface IMetricsApi
/// <param name="k"></param>
/// <returns></returns>
IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT);

/// <summary>
/// Computes the recall of the predictions with respect to the labels.
/// </summary>
/// <param name="thresholds"></param>
/// <param name="top_k"></param>
/// <param name="class_id"></param>
/// <param name="name"></param>
/// <param name="dtype"></param>
/// <returns></returns>
IMetricFunc Recall(float thresholds = 0.5f, int top_k = 1, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT);
}

+ 3
- 0
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -221,6 +221,9 @@ namespace Tensorflow
case Tensor t:
dtype = t.dtype.as_base_dtype();
break;
case int t:
dtype = TF_DataType.TF_INT32;
break;
}

if (dtype != TF_DataType.DtInvalid)


+ 3
- 1
src/TensorFlowNET.Keras/GlobalUsing.cs View File

@@ -1,5 +1,7 @@
global using System;
global using System.Collections.Generic;
global using System.Text;
global using System.Linq;
global using static Tensorflow.Binding;
global using static Tensorflow.KerasApi;
global using static Tensorflow.KerasApi;
global using Tensorflow.NumPy;

+ 1
- 1
src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs View File

@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Metrics
y_true = math_ops.cast(y_true, _dtype);
y_pred = math_ops.cast(y_pred, _dtype);

(y_pred, y_true) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true: y_true);
(y_pred, y_true, _) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true: y_true);

var matches = _fn(y_true, y_pred);
return update_state(matches, sample_weight: sample_weight);


+ 3
- 0
src/TensorFlowNET.Keras/Metrics/MetricsApi.cs View File

@@ -61,5 +61,8 @@

public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
=> new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype);

public IMetricFunc Recall(float thresholds = 0.5f, int top_k = 1, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT)
=> new Recall(thresholds: thresholds, top_k: top_k, class_id: class_id, name: name, dtype: dtype);
}
}

+ 53
- 0
src/TensorFlowNET.Keras/Metrics/Recall.cs View File

@@ -0,0 +1,53 @@
namespace Tensorflow.Keras.Metrics;

public class Recall : Metric
{
Tensor _thresholds;
int _top_k;
int _class_id;
IVariableV1 true_positives;
IVariableV1 false_negatives;
bool _thresholds_distributed_evenly;

public Recall(float thresholds = 0.5f, int top_k = 1, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT)
: base(name: name, dtype: dtype)
{
_thresholds = constant_op.constant(new float[] { thresholds });
true_positives = add_weight("true_positives", shape: 1, initializer: tf.initializers.zeros_initializer());
false_negatives = add_weight("false_negatives", shape: 1, initializer: tf.initializers.zeros_initializer());
}

public override Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{
return metrics_utils.update_confusion_matrix_variables(
new Dictionary<string, IVariableV1>
{
{ "tp", true_positives },
{ "fn", false_negatives },
},
y_true,
y_pred,
thresholds: _thresholds,
thresholds_distributed_evenly: _thresholds_distributed_evenly,
top_k: _top_k,
class_id: _class_id,
sample_weight: sample_weight);
}

public override Tensor result()
{
var result = tf.divide(true_positives.AsTensor(), tf.add(true_positives, false_negatives));
return _thresholds.size == 1 ? result[0] : result;
}

public override void reset_states()
{
var num_thresholds = (int)_thresholds.size;
keras.backend.batch_set_value(
new List<(IVariableV1, NDArray)>
{
(true_positives, np.zeros(num_thresholds)),
(false_negatives, np.zeros(num_thresholds))
});
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Metrics/Reduce.cs View File

@@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Metrics
{
if (sample_weight != null)
{
(values, sample_weight) = losses_utils.squeeze_or_expand_dimensions(
(values, _, sample_weight) = losses_utils.squeeze_or_expand_dimensions(
values, sample_weight: sample_weight);

sample_weight = math_ops.cast(sample_weight, dtype: values.dtype);


+ 170
- 1
src/TensorFlowNET.Keras/Metrics/metrics_utils.cs View File

@@ -1,4 +1,5 @@
using Tensorflow.NumPy;
using Tensorflow.Keras.Utils;
using Tensorflow.NumPy;

namespace Tensorflow.Keras.Metrics;

@@ -36,4 +37,172 @@ public class metrics_utils
return matches;
}

public static Tensor update_confusion_matrix_variables(Dictionary<string, IVariableV1> variables_to_update,
Tensor y_true,
Tensor y_pred,
Tensor thresholds,
int top_k,
int class_id,
Tensor sample_weight = null,
bool multi_label = false,
Tensor label_weights = null,
bool thresholds_distributed_evenly = false)
{
var variable_dtype = variables_to_update.Values.First().dtype;
y_true = tf.cast(y_true, dtype: variable_dtype);
y_pred = tf.cast(y_pred, dtype: variable_dtype);
var num_thresholds = thresholds.shape.dims[0];

Tensor one_thresh = null;
if (multi_label)
{
one_thresh = tf.equal(tf.cast(constant_op.constant(1), dtype:tf.int32),
tf.rank(thresholds),
name: "one_set_of_thresholds_cond");
}
else
{
one_thresh = tf.cast(constant_op.constant(true), dtype: dtypes.@bool);
}

if (sample_weight == null)
{
(y_pred, y_true, _) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true);
}
else
{
sample_weight = tf.cast(sample_weight, dtype: variable_dtype);
(y_pred, y_true, sample_weight) = losses_utils.squeeze_or_expand_dimensions(y_pred,
y_true,
sample_weight: sample_weight);
}

if (thresholds_distributed_evenly)
{
throw new NotImplementedException();
}

var pred_shape = tf.shape(y_pred);
var num_predictions = pred_shape[0];

Tensor num_labels;
if (y_pred.shape.ndim == 1)
{
num_labels = constant_op.constant(1);
}
else
{
num_labels = tf.reduce_prod(pred_shape["1:"], axis: 0);
}
var thresh_label_tile = tf.where(one_thresh, num_labels, tf.ones(new int[0], dtype: tf.int32));

// Reshape predictions and labels, adding a dim for thresholding.
Tensor predictions_extra_dim, labels_extra_dim;
if (multi_label)
{
predictions_extra_dim = tf.expand_dims(y_pred, 0);
labels_extra_dim = tf.expand_dims(tf.cast(y_true, dtype: tf.@bool), 0);
}

else
{
// Flatten predictions and labels when not multilabel.
predictions_extra_dim = tf.reshape(y_pred, (1, -1));
labels_extra_dim = tf.reshape(tf.cast(y_true, dtype: tf.@bool), (1, -1));
}

// Tile the thresholds for every prediction.
object[] thresh_pretile_shape, thresh_tiles, data_tiles;
if (multi_label)
{
thresh_pretile_shape = new object[] { num_thresholds, 1, -1 };
thresh_tiles = new object[] { 1, num_predictions, thresh_label_tile };
data_tiles = new object[] { num_thresholds, 1, 1 };
}
else
{
thresh_pretile_shape = new object[] { num_thresholds, -1 };
thresh_tiles = new object[] { 1, num_predictions * num_labels };
data_tiles = new object[] { num_thresholds, 1 };
}
var thresh_tiled = tf.tile(tf.reshape(thresholds, thresh_pretile_shape), tf.stack(thresh_tiles));

// Tile the predictions for every threshold.
var preds_tiled = tf.tile(predictions_extra_dim, data_tiles);

// Compare predictions and threshold.
var pred_is_pos = tf.greater(preds_tiled, thresh_tiled);

// Tile labels by number of thresholds
var label_is_pos = tf.tile(labels_extra_dim, data_tiles);

Tensor weights_tiled = null;

if (sample_weight != null)
{
/*sample_weight = broadcast_weights(
tf.cast(sample_weight, dtype: variable_dtype), y_pred);*/
weights_tiled = tf.tile(
tf.reshape(sample_weight, thresh_tiles), data_tiles);
}

if (label_weights != null && !multi_label)
{
throw new NotImplementedException();
}

Func<Tensor, Tensor, Tensor, IVariableV1, ITensorOrOperation> weighted_assign_add
= (label, pred, weights, var) =>
{
var label_and_pred = tf.cast(tf.logical_and(label, pred), dtype: var.dtype);
if (weights != null)
{
label_and_pred *= tf.cast(weights, dtype: var.dtype);
}

return var.assign_add(tf.reduce_sum(label_and_pred, 1));
};


var loop_vars = new Dictionary<string, (Tensor, Tensor)>
{
{ "tp", (label_is_pos, pred_is_pos) }
};
var update_tn = variables_to_update.ContainsKey("tn");
var update_fp = variables_to_update.ContainsKey("fp");
var update_fn = variables_to_update.ContainsKey("fn");

Tensor pred_is_neg = null;
if (update_fn || update_tn)
{
pred_is_neg = tf.logical_not(pred_is_pos);
loop_vars["fn"] = (label_is_pos, pred_is_neg);
}

if(update_fp || update_tn)
{
var label_is_neg = tf.logical_not(label_is_pos);
loop_vars["fp"] = (label_is_neg, pred_is_pos);
if (update_tn)
{
loop_vars["tn"] = (label_is_neg, pred_is_neg);
}
}

var update_ops = new List<ITensorOrOperation>();
foreach (var matrix_cond in loop_vars.Keys)
{
var (label, pred) = loop_vars[matrix_cond];
if (variables_to_update.ContainsKey(matrix_cond))
{
var op = weighted_assign_add(label, pred, weights_tiled, variables_to_update[matrix_cond]);
update_ops.append(op);
}
}

tf.group(update_ops.ToArray());
return null;
}
}

+ 4
- 4
src/TensorFlowNET.Keras/Utils/losses_utils.cs View File

@@ -38,7 +38,7 @@ namespace Tensorflow.Keras.Utils
});
}

public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor y_true = null, Tensor sample_weight = null)
public static (Tensor, Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor y_true = null, Tensor sample_weight = null)
{
var y_pred_shape = y_pred.shape;
var y_pred_rank = y_pred_shape.ndim;
@@ -57,13 +57,13 @@ namespace Tensorflow.Keras.Utils

if (sample_weight == null)
{
return (y_pred, y_true);
return (y_pred, y_true, sample_weight);
}

var weights_shape = sample_weight.shape;
var weights_rank = weights_shape.ndim;
if (weights_rank == 0)
return (y_pred, sample_weight);
return (y_pred, y_true, sample_weight);

if (y_pred_rank > -1 && weights_rank > -1)
{
@@ -77,7 +77,7 @@ namespace Tensorflow.Keras.Utils
}
else
{
return (y_pred, sample_weight);
return (y_pred, y_true, sample_weight);
}
}



+ 20
- 0
test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs View File

@@ -45,4 +45,24 @@ public class MetricsTest : EagerModeTestBase
var m = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k: 3);
Assert.AreEqual(m.numpy(), new[] { 1f, 1f });
}

/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Recall
/// </summary>
[TestMethod]
public void Recall()
{
var y_true = np.array(new[] { 0, 1, 1, 1 });
var y_pred = np.array(new[] { 1, 0, 1, 1 });
var m = tf.keras.metrics.Recall();
m.update_state(y_true, y_pred);
var r = m.result().numpy();
Assert.AreEqual(r, 0.6666667f);

m.reset_states();
var weights = np.array(new[] { 0f, 0f, 1f, 0f });
m.update_state(y_true, y_pred, sample_weight: weights);
r = m.result().numpy();
Assert.AreEqual(r, 1f);
}
}

Loading…
Cancel
Save