diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 503ee131..b3607eb1 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -49,6 +49,12 @@ namespace Tensorflow public static void add(this IList list, T element) => list.Add(element); + public static void add(this IList list, IEnumerable elements) + { + foreach (var ele in elements) + list.Add(ele); + } + public static void append(this IList list, T element) => list.Insert(list.Count, element); diff --git a/src/TensorFlowNET.Core/Keras/Engine/LossesContainer.cs b/src/TensorFlowNET.Core/Keras/Engine/LossesContainer.cs index ff34ca8e..974aa6ca 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/LossesContainer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/LossesContainer.cs @@ -11,7 +11,6 @@ namespace Tensorflow.Keras.Engine Mean _loss_metric; bool _built; Tensor[] _per_output_metrics; - List loss_metric_values; public LossesContainer(ILossFunc losses, string[] output_names = null) : base(output_names) @@ -19,7 +18,6 @@ namespace Tensorflow.Keras.Engine _user_losses = losses; _losses = losses; _loss_metric = new Mean(name: "loss"); - loss_metric_values = new List(); _built = false; } @@ -37,6 +35,7 @@ namespace Tensorflow.Keras.Engine var batch_dim = array_ops.shape(y_true)[0]; var loss_values = new List(); + var loss_metric_values = new List(); /*if (_losses.Reduction == ReductionV2.SUM_OVER_BATCH_SIZE || _losses.Reduction == ReductionV2.AUTO) @@ -69,5 +68,16 @@ namespace Tensorflow.Keras.Engine { // _per_output_metrics = _output_names.Select(x => null); } + + public IEnumerable metrics + { + get + { + if (!_built) + return new List(); + + return new[] { _loss_metric }; + } + } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/MetricsContainer.cs b/src/TensorFlowNET.Core/Keras/Engine/MetricsContainer.cs index b8d734a9..990f15ca 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/MetricsContainer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/MetricsContainer.cs @@ -1,29 +1,81 @@ -namespace Tensorflow.Keras.Engine +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.Metrics; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine { public class MetricsContainer : Container { string[] _user_metrics; - string[] _metrics; + string[] _metric_names; + Metric[] _metrics; + List _metrics_in_order; public MetricsContainer(string[] metrics, string[] output_names = null) : base(output_names) { _user_metrics = metrics; - _metrics = metrics; + _metric_names = metrics; _built = false; } public void update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) { if (!_built) - Build(); + Build(y_true, y_pred); + + foreach (var metric_obj in _metrics_in_order) + metric_obj.update_state(y_true, y_pred); + } + void Build(Tensor y_true, Tensor y_pred) + { + _metrics = _get_metric_objects(_metric_names, y_true, y_pred); + _set_metric_names(); + _create_ordered_metrics(); _built = true; } - void Build() + void _set_metric_names() + { + + } + + void _create_ordered_metrics() + { + _metrics_in_order = new List(); + foreach (var m in _metrics) + _metrics_in_order.append(m); + } + + Metric[] _get_metric_objects(string[] metrics, Tensor y_t, Tensor y_p) + { + return metrics.Select(x => _get_metric_object(x, y_t, y_p)).ToArray(); + } + + Metric _get_metric_object(string metric, Tensor y_t, Tensor y_p) + { + Func metric_obj = null; + if (metric == "accuracy" || metric == "acc") + { + metric_obj = tf.keras.metrics.sparse_categorical_accuracy; + return new MeanMetricWrapper(metric_obj, metric); + } + + throw new NotImplementedException(""); + } + + public IEnumerable metrics { + get + { + if (!_built) + return new List(); + return _metrics_in_order; + } } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.Fit.cs index 7eb75d1e..f549e61d 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.Fit.cs @@ -1,5 +1,7 @@ using NumSharp; using System; +using System.Collections.Generic; +using System.Linq; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine.DataAdapters; @@ -51,17 +53,24 @@ namespace Tensorflow.Keras.Engine stop_training = false; _train_counter.assign(0); - + bool first_step = true; foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) { // reset_metrics(); // callbacks.on_epoch_begin(epoch) // data_handler.catch_stop_iteration(); + IEnumerable<(string, Tensor)> results = null; foreach (var step in data_handler.steps()) { // callbacks.on_train_batch_begin(step) - step_function(iterator); + results = step_function(iterator); + if (first_step) + { + Console.WriteLine($"epoch: {epoch}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}"))); + first_step = false; + } } + Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}"))); } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.Metrics.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.Metrics.cs new file mode 100644 index 00000000..f183cb52 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.Metrics.cs @@ -0,0 +1,31 @@ +using System.Collections.Generic; +using Tensorflow.Keras.Metrics; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public partial class Model + { + public IEnumerable metrics + { + get + { + var _metrics = new List(); + if (_is_compiled) + { + if (compiled_loss != null) + _metrics.add(compiled_loss.metrics); + if (compiled_metrics != null) + _metrics.add(compiled_metrics.metrics); + } + + foreach(var layer in _flatten_layers()) + { + // _metrics.extend(layer.metrics); + } + + return _metrics; + } + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs index 6b9caa55..d8aff1e1 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs @@ -8,12 +8,12 @@ namespace Tensorflow.Keras.Engine { public partial class Model { - Tensor step_function(OwnedIterator iterator) + IEnumerable<(string, Tensor)> step_function(OwnedIterator iterator) { var data = iterator.next(); var outputs = train_step(data[0], data[1]); tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); - return null; + return outputs; } /// @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine /// /// /// - IEnumerable<(string, Tensor)> train_step(Tensor x, Tensor y) + List<(string, Tensor)> train_step(Tensor x, Tensor y) { (x, y) = data_handler.DataAdapter.Expand1d(x, y); using var tape = tf.GradientTape(); @@ -36,7 +36,8 @@ namespace Tensorflow.Keras.Engine // such as loss scaling and gradient clipping. _minimize(tape, optimizer, loss, trainable_variables); compiled_metrics.update_state(y, y_pred); - return new[] { ("loss", loss) }; + + return metrics.Select(x => (x.Name, x.result())).ToList(); } void _minimize(GradientTape tape, OptimizerV2 optimizer, Tensor loss, List trainable_variables) diff --git a/src/TensorFlowNET.Core/Keras/KerasApi.cs b/src/TensorFlowNET.Core/Keras/KerasApi.cs index 7afebfac..5d2f15af 100644 --- a/src/TensorFlowNET.Core/Keras/KerasApi.cs +++ b/src/TensorFlowNET.Core/Keras/KerasApi.cs @@ -5,6 +5,7 @@ using Tensorflow.Keras.Datasets; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Layers; using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Metrics; using Tensorflow.Keras.Optimizers; namespace Tensorflow @@ -20,6 +21,7 @@ namespace Tensorflow public Preprocessing preprocessing { get; } = new Preprocessing(); public BackendImpl backend { get; } = new BackendImpl(); public OptimizerApi optimizers { get; } = new OptimizerApi(); + public MetricsApi metrics { get; } = new MetricsApi(); public Sequential Sequential(List layers = null, string name = null) diff --git a/src/TensorFlowNET.Core/Keras/Metrics/Mean.cs b/src/TensorFlowNET.Core/Keras/Metrics/Mean.cs index d3147fae..8a55690b 100644 --- a/src/TensorFlowNET.Core/Keras/Metrics/Mean.cs +++ b/src/TensorFlowNET.Core/Keras/Metrics/Mean.cs @@ -5,7 +5,7 @@ /// public class Mean : Reduce { - public Mean(string name = "mean", TF_DataType dtype = TF_DataType.DtInvalid) + public Mean(string name = "mean", TF_DataType dtype = TF_DataType.TF_FLOAT) : base(Reduction.WEIGHTED_MEAN, name, dtype: dtype) { diff --git a/src/TensorFlowNET.Core/Keras/Metrics/MeanMetricWrapper.cs b/src/TensorFlowNET.Core/Keras/Metrics/MeanMetricWrapper.cs new file mode 100644 index 00000000..3bdfe6d3 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Metrics/MeanMetricWrapper.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Metrics +{ + public class MeanMetricWrapper : Mean + { + string name; + Func _fn = null; + + public MeanMetricWrapper(Func fn, string name, TF_DataType dtype = TF_DataType.TF_FLOAT) + : base(name: name, dtype: dtype) + { + _fn = fn; + } + + public override Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) + { + y_true = math_ops.cast(y_true, _dtype); + y_pred = math_ops.cast(y_pred, _dtype); + + var matches = _fn(y_true, y_pred); + return update_state(matches, sample_weight: sample_weight); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Metrics/Metric.cs b/src/TensorFlowNET.Core/Keras/Metrics/Metric.cs index 59cab031..9cbaaeb7 100644 --- a/src/TensorFlowNET.Core/Keras/Metrics/Metric.cs +++ b/src/TensorFlowNET.Core/Keras/Metrics/Metric.cs @@ -10,6 +10,11 @@ namespace Tensorflow.Keras.Metrics /// public class Metric : Layer { + protected IVariableV1 total; + protected IVariableV1 count; + protected string _reduction; + protected TF_DataType _dtype; + public Metric(string name = null, TF_DataType dtype = TF_DataType.DtInvalid) : base(new LayerArgs { @@ -44,5 +49,14 @@ namespace Tensorflow.Keras.Metrics aggregation: aggregation); }); } + + public virtual Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) + => throw new NotImplementedException(""); + + public virtual Tensor result() + => throw new NotImplementedException(""); + + public override string ToString() + => $"{name} {(float)total.numpy()}/{(float)count.numpy()}"; } } diff --git a/src/TensorFlowNET.Core/Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Core/Keras/Metrics/MetricsApi.cs new file mode 100644 index 00000000..06a73edb --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Metrics/MetricsApi.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Metrics +{ + public class MetricsApi + { + /// + /// Calculates how often predictions matches integer labels. + /// + /// Integer ground truth values. + /// The prediction values. + /// Sparse categorical accuracy values. + public Tensor sparse_categorical_accuracy(Tensor y_true, Tensor y_pred) + { + var y_pred_rank = y_pred.TensorShape.ndim; + var y_true_rank = y_true.TensorShape.ndim; + // If the shape of y_true is (num_samples, 1), squeeze to (num_samples,) + if (y_true_rank != -1 && y_pred_rank != -1 + && y_true.shape.Length == y_pred.shape.Length) + y_true = array_ops.squeeze(y_true, axis: new[] { -1 }); + y_pred = math_ops.argmax(y_pred, -1); + + // If the predicted output and actual output types don't match, force cast them + // to match. + if (y_pred.dtype != y_true.dtype) + y_pred = math_ops.cast(y_pred, y_true.dtype); + + return math_ops.cast(math_ops.equal(y_true, y_pred), TF_DataType.TF_FLOAT); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Metrics/Reduce.cs b/src/TensorFlowNET.Core/Keras/Metrics/Reduce.cs index 15dcbf4b..f7cdb8f5 100644 --- a/src/TensorFlowNET.Core/Keras/Metrics/Reduce.cs +++ b/src/TensorFlowNET.Core/Keras/Metrics/Reduce.cs @@ -9,11 +9,6 @@ namespace Tensorflow.Keras.Metrics /// public class Reduce : Metric { - IVariableV1 total; - IVariableV1 count; - string _reduction; - TF_DataType _dtype; - public Reduce(string reduction, string name, TF_DataType dtype = TF_DataType.DtInvalid) : base(name: name, dtype: dtype) { @@ -43,11 +38,18 @@ namespace Tensorflow.Keras.Metrics var value_sum = math_ops.reduce_sum(values); tf_with(ops.control_dependencies(new[] { value_sum }), ctl => { - var update_total_op = total.assign_add(value_sum); + update_total_op = total.assign_add(value_sum); }); + // Exit early if the reduction doesn't have a denominator. + if (_reduction == Reduction.SUM) + return update_total_op; + + // Update `count` for reductions that require a denominator. Tensor num_values = null; - if (_reduction == ReductionV2.WEIGHTED_MEAN) + if (_reduction == Reduction.SUM_OVER_BATCH_SIZE) + num_values = math_ops.cast(array_ops.size(values), _dtype); + else if (_reduction == ReductionV2.WEIGHTED_MEAN) { if (sample_weight == null) num_values = math_ops.cast(array_ops.size(values), _dtype); @@ -58,5 +60,15 @@ namespace Tensorflow.Keras.Metrics return tf_with(ops.control_dependencies(new[] { update_total_op }), ctl => count.assign_add(num_values)); } + + public override Tensor result() + { + if (_reduction == Reduction.SUM) + return array_ops.identity(total.AsTensor()); + else if (_reduction == Reduction.WEIGHTED_MEAN || _reduction == Reduction.SUM_OVER_BATCH_SIZE) + return math_ops.div_no_nan(total.AsTensor(), count.AsTensor()); + + return base.result(); + } } } diff --git a/src/TensorFlowNET.Core/Operations/Losses/Reduction.cs b/src/TensorFlowNET.Core/Operations/Losses/Reduction.cs index 0a93ae92..bef48546 100644 --- a/src/TensorFlowNET.Core/Operations/Losses/Reduction.cs +++ b/src/TensorFlowNET.Core/Operations/Losses/Reduction.cs @@ -3,6 +3,7 @@ public class Reduction { public const string NONE = "none"; + public const string SUM = "sum"; public const string WEIGHTED_SUM = "weighted_sum"; public const string SUM_OVER_BATCH_SIZE = "weighted_sum_over_batch_size"; public const string WEIGHTED_MEAN = "weighted_mean"; diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index e5b29ea2..17728582 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -68,6 +68,9 @@ namespace Tensorflow return gen_math_ops.add_n(inputs, name: name); } + public static Tensor argmax(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) + => gen_math_ops.arg_max(input, dimension, output_type: output_type, name: name); + public static Tensor round(Tensor x, string name = null) { x = ops.convert_to_tensor(x, name: "x");