From 271dcefc15c5f5b5170c00304820458b5cfa8de3 Mon Sep 17 00:00:00 2001 From: Haiping Chen Date: Sun, 5 Feb 2023 12:46:31 -0600 Subject: [PATCH] fix keras model predict return result. --- .../Sessions/BaseSession.cs | 3 +- .../Callbacks/CallbackList.cs | 20 +++++++ src/TensorFlowNET.Keras/Callbacks/History.cs | 21 +++++++ .../Callbacks/ICallback.cs | 4 ++ .../Callbacks/ProgbarLogger.cs | 24 +++++++- .../Engine/Model.Predict.cs | 59 +++++++++++++++++++ 6 files changed, 128 insertions(+), 3 deletions(-) diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 01ba0407..095187b9 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -293,7 +293,8 @@ namespace Tensorflow // c_api.TF_CloseSession(handle, tf.Status.Handle); if (tf.Status == null || tf.Status.Handle.IsInvalid) { - c_api.TF_DeleteSession(handle, c_api.TF_NewStatus()); + using var status = new Status(); + c_api.TF_DeleteSession(handle, status.Handle); } else { diff --git a/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs b/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs index bb3ed6ed..54e3780a 100644 --- a/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs +++ b/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs @@ -39,5 +39,25 @@ namespace Tensorflow.Keras.Callbacks { callbacks.ForEach(x => x.on_epoch_end(epoch, epoch_logs)); } + + public void on_predict_begin() + { + callbacks.ForEach(x => x.on_predict_begin()); + } + + public void on_predict_batch_begin(long step) + { + callbacks.ForEach(x => x.on_predict_batch_begin(step)); + } + + public void on_predict_batch_end(long end_step, Dictionary logs) + { + callbacks.ForEach(x => x.on_predict_batch_end(end_step, logs)); + } + + public void on_predict_end() + { + callbacks.ForEach(x => x.on_predict_end()); + } } } diff --git a/src/TensorFlowNET.Keras/Callbacks/History.cs b/src/TensorFlowNET.Keras/Callbacks/History.cs index 02588b5e..89e1834b 100644 --- a/src/TensorFlowNET.Keras/Callbacks/History.cs +++ b/src/TensorFlowNET.Keras/Callbacks/History.cs @@ -48,5 +48,26 @@ namespace Tensorflow.Keras.Callbacks history[log.Key].Add((float)log.Value); } } + + public void on_predict_begin() + { + epochs = new List(); + history = new Dictionary>(); + } + + public void on_predict_batch_begin(long step) + { + + } + + public void on_predict_batch_end(long end_step, Dictionary logs) + { + + } + + public void on_predict_end() + { + + } } } diff --git a/src/TensorFlowNET.Keras/Callbacks/ICallback.cs b/src/TensorFlowNET.Keras/Callbacks/ICallback.cs index 34763c55..7d71ccac 100644 --- a/src/TensorFlowNET.Keras/Callbacks/ICallback.cs +++ b/src/TensorFlowNET.Keras/Callbacks/ICallback.cs @@ -11,5 +11,9 @@ namespace Tensorflow.Keras.Callbacks void on_train_batch_begin(long step); void on_train_batch_end(long end_step, Dictionary logs); void on_epoch_end(int epoch, Dictionary epoch_logs); + void on_predict_begin(); + void on_predict_batch_begin(long step); + void on_predict_batch_end(long end_step, Dictionary logs); + void on_predict_end(); } } diff --git a/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs b/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs index 17e04101..bb18b2cb 100644 --- a/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs +++ b/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs @@ -1,5 +1,4 @@ -using PureHDF; -using System; +using System; using System.Collections.Generic; using System.Diagnostics; using System.Linq; @@ -77,5 +76,26 @@ namespace Tensorflow.Keras.Callbacks { } + + public void on_predict_begin() + { + _reset_progbar(); + _maybe_init_progbar(); + } + + public void on_predict_batch_begin(long step) + { + + } + + public void on_predict_batch_end(long end_step, Dictionary logs) + { + + } + + public void on_predict_end() + { + + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Predict.cs b/src/TensorFlowNET.Keras/Engine/Model.Predict.cs index 6dbce98c..4d5755b0 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Predict.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Predict.cs @@ -5,11 +5,70 @@ using System.Linq; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine.DataAdapters; using static Tensorflow.Binding; +using Tensorflow.Keras.Callbacks; namespace Tensorflow.Keras.Engine { public partial class Model { + public Tensors predict(IDatasetV2 dataset, + int batch_size = -1, + int verbose = 0, + int steps = -1, + int max_queue_size = 10, + int workers = 1, + bool use_multiprocessing = false) + { + var data_handler = new DataHandler(new DataHandlerArgs + { + Dataset = dataset, + BatchSize = batch_size, + StepsPerEpoch = steps, + InitialEpoch = 0, + Epochs = 1, + MaxQueueSize = max_queue_size, + Workers = workers, + UseMultiprocessing = use_multiprocessing, + Model = this, + StepsPerExecution = _steps_per_execution + }); + + var callbacks = new CallbackList(new CallbackParams + { + Model = this, + Verbose = verbose, + Epochs = 1, + Steps = data_handler.Inferredsteps + }); + + Tensor batch_outputs = null; + _predict_counter.assign(0); + callbacks.on_predict_begin(); + foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) + { + foreach (var step in data_handler.steps()) + { + callbacks.on_predict_batch_begin(step); + var tmp_batch_outputs = run_predict_step(iterator); + if (batch_outputs == null) + { + batch_outputs = tmp_batch_outputs[0]; + } + else + { + batch_outputs = tf.concat(new Tensor[] { batch_outputs, tmp_batch_outputs[0] }, axis: 0); + } + + var end_step = step + data_handler.StepIncrement; + callbacks.on_predict_batch_end(end_step, new Dictionary { { "outputs", batch_outputs } }); + } + GC.Collect(); + } + + callbacks.on_predict_end(); + return batch_outputs; + } + /// /// Generates output predictions for the input samples. ///