@@ -35,4 +35,15 @@ public interface IMetricsApi | |||
/// <param name="k"></param> | |||
/// <returns></returns> | |||
IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT); | |||
/// <summary> | |||
/// Computes the recall of the predictions with respect to the labels. | |||
/// </summary> | |||
/// <param name="thresholds"></param> | |||
/// <param name="top_k"></param> | |||
/// <param name="class_id"></param> | |||
/// <param name="name"></param> | |||
/// <param name="dtype"></param> | |||
/// <returns></returns> | |||
IMetricFunc Recall(float thresholds = 0.5f, int top_k = 1, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT); | |||
} |
@@ -221,6 +221,9 @@ namespace Tensorflow | |||
case Tensor t: | |||
dtype = t.dtype.as_base_dtype(); | |||
break; | |||
case int t: | |||
dtype = TF_DataType.TF_INT32; | |||
break; | |||
} | |||
if (dtype != TF_DataType.DtInvalid) | |||
@@ -1,5 +1,7 @@ | |||
global using System; | |||
global using System.Collections.Generic; | |||
global using System.Text; | |||
global using System.Linq; | |||
global using static Tensorflow.Binding; | |||
global using static Tensorflow.KerasApi; | |||
global using static Tensorflow.KerasApi; | |||
global using Tensorflow.NumPy; |
@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Metrics | |||
y_true = math_ops.cast(y_true, _dtype); | |||
y_pred = math_ops.cast(y_pred, _dtype); | |||
(y_pred, y_true) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true: y_true); | |||
(y_pred, y_true, _) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true: y_true); | |||
var matches = _fn(y_true, y_pred); | |||
return update_state(matches, sample_weight: sample_weight); | |||
@@ -61,5 +61,8 @@ | |||
public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) | |||
=> new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype); | |||
public IMetricFunc Recall(float thresholds = 0.5f, int top_k = 1, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT) | |||
=> new Recall(thresholds: thresholds, top_k: top_k, class_id: class_id, name: name, dtype: dtype); | |||
} | |||
} |
@@ -0,0 +1,53 @@ | |||
namespace Tensorflow.Keras.Metrics; | |||
public class Recall : Metric | |||
{ | |||
Tensor _thresholds; | |||
int _top_k; | |||
int _class_id; | |||
IVariableV1 true_positives; | |||
IVariableV1 false_negatives; | |||
bool _thresholds_distributed_evenly; | |||
public Recall(float thresholds = 0.5f, int top_k = 1, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT) | |||
: base(name: name, dtype: dtype) | |||
{ | |||
_thresholds = constant_op.constant(new float[] { thresholds }); | |||
true_positives = add_weight("true_positives", shape: 1, initializer: tf.initializers.zeros_initializer()); | |||
false_negatives = add_weight("false_negatives", shape: 1, initializer: tf.initializers.zeros_initializer()); | |||
} | |||
public override Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) | |||
{ | |||
return metrics_utils.update_confusion_matrix_variables( | |||
new Dictionary<string, IVariableV1> | |||
{ | |||
{ "tp", true_positives }, | |||
{ "fn", false_negatives }, | |||
}, | |||
y_true, | |||
y_pred, | |||
thresholds: _thresholds, | |||
thresholds_distributed_evenly: _thresholds_distributed_evenly, | |||
top_k: _top_k, | |||
class_id: _class_id, | |||
sample_weight: sample_weight); | |||
} | |||
public override Tensor result() | |||
{ | |||
var result = tf.divide(true_positives.AsTensor(), tf.add(true_positives, false_negatives)); | |||
return _thresholds.size == 1 ? result[0] : result; | |||
} | |||
public override void reset_states() | |||
{ | |||
var num_thresholds = (int)_thresholds.size; | |||
keras.backend.batch_set_value( | |||
new List<(IVariableV1, NDArray)> | |||
{ | |||
(true_positives, np.zeros(num_thresholds)), | |||
(false_negatives, np.zeros(num_thresholds)) | |||
}); | |||
} | |||
} |
@@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Metrics | |||
{ | |||
if (sample_weight != null) | |||
{ | |||
(values, sample_weight) = losses_utils.squeeze_or_expand_dimensions( | |||
(values, _, sample_weight) = losses_utils.squeeze_or_expand_dimensions( | |||
values, sample_weight: sample_weight); | |||
sample_weight = math_ops.cast(sample_weight, dtype: values.dtype); | |||
@@ -1,4 +1,5 @@ | |||
using Tensorflow.NumPy; | |||
using Tensorflow.Keras.Utils; | |||
using Tensorflow.NumPy; | |||
namespace Tensorflow.Keras.Metrics; | |||
@@ -36,4 +37,172 @@ public class metrics_utils | |||
return matches; | |||
} | |||
public static Tensor update_confusion_matrix_variables(Dictionary<string, IVariableV1> variables_to_update, | |||
Tensor y_true, | |||
Tensor y_pred, | |||
Tensor thresholds, | |||
int top_k, | |||
int class_id, | |||
Tensor sample_weight = null, | |||
bool multi_label = false, | |||
Tensor label_weights = null, | |||
bool thresholds_distributed_evenly = false) | |||
{ | |||
var variable_dtype = variables_to_update.Values.First().dtype; | |||
y_true = tf.cast(y_true, dtype: variable_dtype); | |||
y_pred = tf.cast(y_pred, dtype: variable_dtype); | |||
var num_thresholds = thresholds.shape.dims[0]; | |||
Tensor one_thresh = null; | |||
if (multi_label) | |||
{ | |||
one_thresh = tf.equal(tf.cast(constant_op.constant(1), dtype:tf.int32), | |||
tf.rank(thresholds), | |||
name: "one_set_of_thresholds_cond"); | |||
} | |||
else | |||
{ | |||
one_thresh = tf.cast(constant_op.constant(true), dtype: dtypes.@bool); | |||
} | |||
if (sample_weight == null) | |||
{ | |||
(y_pred, y_true, _) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true); | |||
} | |||
else | |||
{ | |||
sample_weight = tf.cast(sample_weight, dtype: variable_dtype); | |||
(y_pred, y_true, sample_weight) = losses_utils.squeeze_or_expand_dimensions(y_pred, | |||
y_true, | |||
sample_weight: sample_weight); | |||
} | |||
if (thresholds_distributed_evenly) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
var pred_shape = tf.shape(y_pred); | |||
var num_predictions = pred_shape[0]; | |||
Tensor num_labels; | |||
if (y_pred.shape.ndim == 1) | |||
{ | |||
num_labels = constant_op.constant(1); | |||
} | |||
else | |||
{ | |||
num_labels = tf.reduce_prod(pred_shape["1:"], axis: 0); | |||
} | |||
var thresh_label_tile = tf.where(one_thresh, num_labels, tf.ones(new int[0], dtype: tf.int32)); | |||
// Reshape predictions and labels, adding a dim for thresholding. | |||
Tensor predictions_extra_dim, labels_extra_dim; | |||
if (multi_label) | |||
{ | |||
predictions_extra_dim = tf.expand_dims(y_pred, 0); | |||
labels_extra_dim = tf.expand_dims(tf.cast(y_true, dtype: tf.@bool), 0); | |||
} | |||
else | |||
{ | |||
// Flatten predictions and labels when not multilabel. | |||
predictions_extra_dim = tf.reshape(y_pred, (1, -1)); | |||
labels_extra_dim = tf.reshape(tf.cast(y_true, dtype: tf.@bool), (1, -1)); | |||
} | |||
// Tile the thresholds for every prediction. | |||
object[] thresh_pretile_shape, thresh_tiles, data_tiles; | |||
if (multi_label) | |||
{ | |||
thresh_pretile_shape = new object[] { num_thresholds, 1, -1 }; | |||
thresh_tiles = new object[] { 1, num_predictions, thresh_label_tile }; | |||
data_tiles = new object[] { num_thresholds, 1, 1 }; | |||
} | |||
else | |||
{ | |||
thresh_pretile_shape = new object[] { num_thresholds, -1 }; | |||
thresh_tiles = new object[] { 1, num_predictions * num_labels }; | |||
data_tiles = new object[] { num_thresholds, 1 }; | |||
} | |||
var thresh_tiled = tf.tile(tf.reshape(thresholds, thresh_pretile_shape), tf.stack(thresh_tiles)); | |||
// Tile the predictions for every threshold. | |||
var preds_tiled = tf.tile(predictions_extra_dim, data_tiles); | |||
// Compare predictions and threshold. | |||
var pred_is_pos = tf.greater(preds_tiled, thresh_tiled); | |||
// Tile labels by number of thresholds | |||
var label_is_pos = tf.tile(labels_extra_dim, data_tiles); | |||
Tensor weights_tiled = null; | |||
if (sample_weight != null) | |||
{ | |||
/*sample_weight = broadcast_weights( | |||
tf.cast(sample_weight, dtype: variable_dtype), y_pred);*/ | |||
weights_tiled = tf.tile( | |||
tf.reshape(sample_weight, thresh_tiles), data_tiles); | |||
} | |||
if (label_weights != null && !multi_label) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
Func<Tensor, Tensor, Tensor, IVariableV1, ITensorOrOperation> weighted_assign_add | |||
= (label, pred, weights, var) => | |||
{ | |||
var label_and_pred = tf.cast(tf.logical_and(label, pred), dtype: var.dtype); | |||
if (weights != null) | |||
{ | |||
label_and_pred *= tf.cast(weights, dtype: var.dtype); | |||
} | |||
return var.assign_add(tf.reduce_sum(label_and_pred, 1)); | |||
}; | |||
var loop_vars = new Dictionary<string, (Tensor, Tensor)> | |||
{ | |||
{ "tp", (label_is_pos, pred_is_pos) } | |||
}; | |||
var update_tn = variables_to_update.ContainsKey("tn"); | |||
var update_fp = variables_to_update.ContainsKey("fp"); | |||
var update_fn = variables_to_update.ContainsKey("fn"); | |||
Tensor pred_is_neg = null; | |||
if (update_fn || update_tn) | |||
{ | |||
pred_is_neg = tf.logical_not(pred_is_pos); | |||
loop_vars["fn"] = (label_is_pos, pred_is_neg); | |||
} | |||
if(update_fp || update_tn) | |||
{ | |||
var label_is_neg = tf.logical_not(label_is_pos); | |||
loop_vars["fp"] = (label_is_neg, pred_is_pos); | |||
if (update_tn) | |||
{ | |||
loop_vars["tn"] = (label_is_neg, pred_is_neg); | |||
} | |||
} | |||
var update_ops = new List<ITensorOrOperation>(); | |||
foreach (var matrix_cond in loop_vars.Keys) | |||
{ | |||
var (label, pred) = loop_vars[matrix_cond]; | |||
if (variables_to_update.ContainsKey(matrix_cond)) | |||
{ | |||
var op = weighted_assign_add(label, pred, weights_tiled, variables_to_update[matrix_cond]); | |||
update_ops.append(op); | |||
} | |||
} | |||
tf.group(update_ops.ToArray()); | |||
return null; | |||
} | |||
} |
@@ -38,7 +38,7 @@ namespace Tensorflow.Keras.Utils | |||
}); | |||
} | |||
public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor y_true = null, Tensor sample_weight = null) | |||
public static (Tensor, Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor y_true = null, Tensor sample_weight = null) | |||
{ | |||
var y_pred_shape = y_pred.shape; | |||
var y_pred_rank = y_pred_shape.ndim; | |||
@@ -57,13 +57,13 @@ namespace Tensorflow.Keras.Utils | |||
if (sample_weight == null) | |||
{ | |||
return (y_pred, y_true); | |||
return (y_pred, y_true, sample_weight); | |||
} | |||
var weights_shape = sample_weight.shape; | |||
var weights_rank = weights_shape.ndim; | |||
if (weights_rank == 0) | |||
return (y_pred, sample_weight); | |||
return (y_pred, y_true, sample_weight); | |||
if (y_pred_rank > -1 && weights_rank > -1) | |||
{ | |||
@@ -77,7 +77,7 @@ namespace Tensorflow.Keras.Utils | |||
} | |||
else | |||
{ | |||
return (y_pred, sample_weight); | |||
return (y_pred, y_true, sample_weight); | |||
} | |||
} | |||
@@ -45,4 +45,24 @@ public class MetricsTest : EagerModeTestBase | |||
var m = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k: 3); | |||
Assert.AreEqual(m.numpy(), new[] { 1f, 1f }); | |||
} | |||
/// <summary> | |||
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Recall | |||
/// </summary> | |||
[TestMethod] | |||
public void Recall() | |||
{ | |||
var y_true = np.array(new[] { 0, 1, 1, 1 }); | |||
var y_pred = np.array(new[] { 1, 0, 1, 1 }); | |||
var m = tf.keras.metrics.Recall(); | |||
m.update_state(y_true, y_pred); | |||
var r = m.result().numpy(); | |||
Assert.AreEqual(r, 0.6666667f); | |||
m.reset_states(); | |||
var weights = np.array(new[] { 0f, 0f, 1f, 0f }); | |||
m.update_state(y_true, y_pred, sample_weight: weights); | |||
r = m.result().numpy(); | |||
Assert.AreEqual(r, 1f); | |||
} | |||
} |