From 0effee430c905f7ee84a064a4b1474ef931368a0 Mon Sep 17 00:00:00 2001 From: Luc Bologna Date: Mon, 5 Jun 2023 20:14:57 +0200 Subject: [PATCH] Update Model.Evaluate.cs Fix my bad: Bad handling between test_function and test_step_multi_inputs_function. --- .../Engine/Model.Evaluate.cs | 116 +++++++++++------- 1 file changed, 75 insertions(+), 41 deletions(-) diff --git a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs index 85c262a9..99a891c0 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs @@ -1,51 +1,19 @@ -using Tensorflow.NumPy; using System; using System.Collections.Generic; using System.Linq; +using Tensorflow; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Callbacks; using Tensorflow.Keras.Engine.DataAdapters; -using static Tensorflow.Binding; using Tensorflow.Keras.Layers; using Tensorflow.Keras.Utils; -using Tensorflow; -using Tensorflow.Keras.Callbacks; +using Tensorflow.NumPy; +using static Tensorflow.Binding; namespace Tensorflow.Keras.Engine { public partial class Model { - protected Dictionary evaluate(CallbackList callbacks, DataHandler data_handler, bool is_val) - { - callbacks.on_test_begin(); - - //Dictionary? logs = null; - var logs = new Dictionary(); - int x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; - foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) - { - reset_metrics(); - callbacks.on_epoch_begin(epoch); - // data_handler.catch_stop_iteration(); - - foreach (var step in data_handler.steps()) - { - callbacks.on_test_batch_begin(step); - - var data = iterator.next(); - - logs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size))); - tf_with(ops.control_dependencies(Array.Empty()), ctl => _test_counter.assign_add(1)); - - var end_step = step + data_handler.StepIncrement; - - if (!is_val) - callbacks.on_test_batch_end(end_step, logs); - } - } - - return logs; - } - /// /// Returns the loss value & metrics values for the model in test mode. /// @@ -97,7 +65,7 @@ namespace Tensorflow.Keras.Engine Steps = data_handler.Inferredsteps }); - return evaluate(callbacks, data_handler, is_val); + return evaluate(data_handler, callbacks, is_val, test_function); } public Dictionary evaluate(IEnumerable x, Tensor y, int verbose = 1, bool is_val = false) @@ -117,10 +85,9 @@ namespace Tensorflow.Keras.Engine Steps = data_handler.Inferredsteps }); - return evaluate(callbacks, data_handler, is_val); + return evaluate(data_handler, callbacks, is_val, test_step_multi_inputs_function); } - public Dictionary evaluate(IDatasetV2 x, int verbose = 1, bool is_val = false) { var data_handler = new DataHandler(new DataHandlerArgs @@ -137,7 +104,74 @@ namespace Tensorflow.Keras.Engine Steps = data_handler.Inferredsteps }); - return evaluate(callbacks, data_handler, is_val); + return evaluate(data_handler, callbacks, is_val, test_function); + } + + /// + /// Internal bare implementation of evaluate function. + /// + /// Interations handling objects + /// + /// The function to be called on each batch of data. + /// Whether it is validation or test. + /// + Dictionary evaluate(DataHandler data_handler, CallbackList callbacks, bool is_val, Func> test_func) + { + callbacks.on_test_begin(); + + var results = new Dictionary(); + var logs = results; + foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) + { + reset_metrics(); + callbacks.on_epoch_begin(epoch); + // data_handler.catch_stop_iteration(); + + foreach (var step in data_handler.steps()) + { + callbacks.on_test_batch_begin(step); + + var data = iterator.next(); + + logs = test_func(data_handler, iterator.next()); + + tf_with(ops.control_dependencies(Array.Empty()), ctl => _train_counter.assign_add(1)); + + var end_step = step + data_handler.StepIncrement; + if (!is_val) + callbacks.on_test_batch_end(end_step, logs); + } + + if (!is_val) + callbacks.on_epoch_end(epoch, logs); + } + + foreach (var log in logs) + { + results[log.Key] = log.Value; + } + + return results; + } + + Dictionary test_function(DataHandler data_handler, Tensor[] data) + { + var (x, y) = data_handler.DataAdapter.Expand1d(data[0], data[1]); + + var y_pred = Apply(x, training: false); + var loss = compiled_loss.Call(y, y_pred); + + compiled_metrics.update_state(y, y_pred); + + var outputs = metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Name, x => (float)x.Item2); + return outputs; + } + + Dictionary test_step_multi_inputs_function(DataHandler data_handler, Tensor[] data) + { + var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; + var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size))); + return outputs; } } -} \ No newline at end of file +}