Browse Source

fix keras model predict return result.

tags/v0.100.4-load-saved-model
Haiping Chen 2 years ago
parent
commit
271dcefc15
6 changed files with 128 additions and 3 deletions
  1. +2
    -1
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  2. +20
    -0
      src/TensorFlowNET.Keras/Callbacks/CallbackList.cs
  3. +21
    -0
      src/TensorFlowNET.Keras/Callbacks/History.cs
  4. +4
    -0
      src/TensorFlowNET.Keras/Callbacks/ICallback.cs
  5. +22
    -2
      src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs
  6. +59
    -0
      src/TensorFlowNET.Keras/Engine/Model.Predict.cs

+ 2
- 1
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -293,7 +293,8 @@ namespace Tensorflow
// c_api.TF_CloseSession(handle, tf.Status.Handle);
if (tf.Status == null || tf.Status.Handle.IsInvalid)
{
c_api.TF_DeleteSession(handle, c_api.TF_NewStatus());
using var status = new Status();
c_api.TF_DeleteSession(handle, status.Handle);
}
else
{


+ 20
- 0
src/TensorFlowNET.Keras/Callbacks/CallbackList.cs View File

@@ -39,5 +39,25 @@ namespace Tensorflow.Keras.Callbacks
{
callbacks.ForEach(x => x.on_epoch_end(epoch, epoch_logs));
}

public void on_predict_begin()
{
callbacks.ForEach(x => x.on_predict_begin());
}

public void on_predict_batch_begin(long step)
{
callbacks.ForEach(x => x.on_predict_batch_begin(step));
}

public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs)
{
callbacks.ForEach(x => x.on_predict_batch_end(end_step, logs));
}

public void on_predict_end()
{
callbacks.ForEach(x => x.on_predict_end());
}
}
}

+ 21
- 0
src/TensorFlowNET.Keras/Callbacks/History.cs View File

@@ -48,5 +48,26 @@ namespace Tensorflow.Keras.Callbacks
history[log.Key].Add((float)log.Value);
}
}

public void on_predict_begin()
{
epochs = new List<int>();
history = new Dictionary<string, List<float>>();
}

public void on_predict_batch_begin(long step)
{
}

public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs)
{
}

public void on_predict_end()
{
}
}
}

+ 4
- 0
src/TensorFlowNET.Keras/Callbacks/ICallback.cs View File

@@ -11,5 +11,9 @@ namespace Tensorflow.Keras.Callbacks
void on_train_batch_begin(long step);
void on_train_batch_end(long end_step, Dictionary<string, float> logs);
void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs);
void on_predict_begin();
void on_predict_batch_begin(long step);
void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs);
void on_predict_end();
}
}

+ 22
- 2
src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs View File

@@ -1,5 +1,4 @@
using PureHDF;
using System;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
@@ -77,5 +76,26 @@ namespace Tensorflow.Keras.Callbacks
{

}

public void on_predict_begin()
{
_reset_progbar();
_maybe_init_progbar();
}

public void on_predict_batch_begin(long step)
{
}

public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs)
{
}

public void on_predict_end()
{
}
}
}

+ 59
- 0
src/TensorFlowNET.Keras/Engine/Model.Predict.cs View File

@@ -5,11 +5,70 @@ using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine.DataAdapters;
using static Tensorflow.Binding;
using Tensorflow.Keras.Callbacks;

namespace Tensorflow.Keras.Engine
{
public partial class Model
{
public Tensors predict(IDatasetV2 dataset,
int batch_size = -1,
int verbose = 0,
int steps = -1,
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false)
{
var data_handler = new DataHandler(new DataHandlerArgs
{
Dataset = dataset,
BatchSize = batch_size,
StepsPerEpoch = steps,
InitialEpoch = 0,
Epochs = 1,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
Model = this,
StepsPerExecution = _steps_per_execution
});

var callbacks = new CallbackList(new CallbackParams
{
Model = this,
Verbose = verbose,
Epochs = 1,
Steps = data_handler.Inferredsteps
});

Tensor batch_outputs = null;
_predict_counter.assign(0);
callbacks.on_predict_begin();
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
foreach (var step in data_handler.steps())
{
callbacks.on_predict_batch_begin(step);
var tmp_batch_outputs = run_predict_step(iterator);
if (batch_outputs == null)
{
batch_outputs = tmp_batch_outputs[0];
}
else
{
batch_outputs = tf.concat(new Tensor[] { batch_outputs, tmp_batch_outputs[0] }, axis: 0);
}

var end_step = step + data_handler.StepIncrement;
callbacks.on_predict_batch_end(end_step, new Dictionary<string, Tensors> { { "outputs", batch_outputs } });
}
GC.Collect();
}
callbacks.on_predict_end();
return batch_outputs;
}

/// <summary>
/// Generates output predictions for the input samples.
/// </summary>


Loading…
Cancel
Save