diff --git a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs index a71f7f39..85c262a9 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs @@ -14,6 +14,38 @@ namespace Tensorflow.Keras.Engine { public partial class Model { + protected Dictionary evaluate(CallbackList callbacks, DataHandler data_handler, bool is_val) + { + callbacks.on_test_begin(); + + //Dictionary? logs = null; + var logs = new Dictionary(); + int x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; + foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) + { + reset_metrics(); + callbacks.on_epoch_begin(epoch); + // data_handler.catch_stop_iteration(); + + foreach (var step in data_handler.steps()) + { + callbacks.on_test_batch_begin(step); + + var data = iterator.next(); + + logs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size))); + tf_with(ops.control_dependencies(Array.Empty()), ctl => _test_counter.assign_add(1)); + + var end_step = step + data_handler.StepIncrement; + + if (!is_val) + callbacks.on_test_batch_end(end_step, logs); + } + } + + return logs; + } + /// /// Returns the loss value & metrics values for the model in test mode. /// @@ -64,31 +96,8 @@ namespace Tensorflow.Keras.Engine Verbose = verbose, Steps = data_handler.Inferredsteps }); - callbacks.on_test_begin(); - - //Dictionary? logs = null; - var logs = new Dictionary(); - foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) - { - reset_metrics(); - // data_handler.catch_stop_iteration(); - foreach (var step in data_handler.steps()) - { - callbacks.on_test_batch_begin(step); - logs = test_function(data_handler, iterator); - var end_step = step + data_handler.StepIncrement; - if (is_val == false) - callbacks.on_test_batch_end(end_step, logs); - } - } - - var results = new Dictionary(); - foreach (var log in logs) - { - results[log.Key] = log.Value; - } - return results; + return evaluate(callbacks, data_handler, is_val); } public Dictionary evaluate(IEnumerable x, Tensor y, int verbose = 1, bool is_val = false) @@ -107,31 +116,8 @@ namespace Tensorflow.Keras.Engine Verbose = verbose, Steps = data_handler.Inferredsteps }); - callbacks.on_test_begin(); - Dictionary logs = null; - foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) - { - reset_metrics(); - callbacks.on_epoch_begin(epoch); - // data_handler.catch_stop_iteration(); - - foreach (var step in data_handler.steps()) - { - callbacks.on_test_batch_begin(step); - logs = test_function(data_handler, iterator); - var end_step = step + data_handler.StepIncrement; - if (is_val == false) - callbacks.on_test_batch_end(end_step, logs); - } - } - - var results = new Dictionary(); - foreach (var log in logs) - { - results[log.Key] = log.Value; - } - return results; + return evaluate(callbacks, data_handler, is_val); } @@ -150,51 +136,8 @@ namespace Tensorflow.Keras.Engine Verbose = verbose, Steps = data_handler.Inferredsteps }); - callbacks.on_test_begin(); - - Dictionary logs = null; - foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) - { - reset_metrics(); - callbacks.on_epoch_begin(epoch); - // data_handler.catch_stop_iteration(); - - foreach (var step in data_handler.steps()) - { - callbacks.on_test_batch_begin(step); - logs = test_function(data_handler, iterator); - var end_step = step + data_handler.StepIncrement; - if (is_val == false) - callbacks.on_test_batch_end(end_step, logs); - } - } - - var results = new Dictionary(); - foreach (var log in logs) - { - results[log.Key] = log.Value; - } - return results; - } - - Dictionary test_function(DataHandler data_handler, OwnedIterator iterator) - { - var data = iterator.next(); - var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; - var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size))); - tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1)); - return outputs; - } - - Dictionary test_step(DataHandler data_handler, Tensor x, Tensor y) - { - (x, y) = data_handler.DataAdapter.Expand1d(x, y); - var y_pred = Apply(x, training: false); - var loss = compiled_loss.Call(y, y_pred); - - compiled_metrics.update_state(y, y_pred); - return metrics.Select(x => (x.Name, x.result())).ToDictionary(x=>x.Item1, x=>(float)x.Item2); + return evaluate(callbacks, data_handler, is_val); } } -} +} \ No newline at end of file