using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine.DataAdapters;
using static Tensorflow.Binding;
using Tensorflow.Keras.Callbacks;
namespace Tensorflow.Keras.Engine
{
public partial class Model
{
/// <summary>
/// Generates output predictions for the samples yielded by a dataset.
/// Runs a single pass (one epoch) over <paramref name="dataset"/>.
/// </summary>
/// <param name="dataset">Input dataset to draw prediction batches from.</param>
/// <param name="batch_size">Number of samples per batch (-1 lets the data handler decide).</param>
/// <param name="verbose">Verbosity mode.</param>
/// <param name="steps">
/// Total number of steps (batches of samples)
/// before declaring the prediction round finished (-1 means run the dataset to exhaustion).
/// </param>
/// <param name="max_queue_size">Maximum size of the generator queue.</param>
/// <param name="workers">Number of workers used by the data handler.</param>
/// <param name="use_multiprocessing">Whether the data handler may use process-based parallelism.</param>
/// <returns>Concatenated prediction outputs for all batches.</returns>
public Tensors predict(IDatasetV2 dataset,
    int batch_size = -1,
    int verbose = 0,
    int steps = -1,
    int max_queue_size = 10,
    int workers = 1,
    bool use_multiprocessing = false)
{
    // Single-epoch handler: prediction never revisits the data.
    var handlerArgs = new DataHandlerArgs
    {
        Dataset = dataset,
        BatchSize = batch_size,
        StepsPerEpoch = steps,
        InitialEpoch = 0,
        Epochs = 1,
        MaxQueueSize = max_queue_size,
        Workers = workers,
        UseMultiprocessing = use_multiprocessing,
        Model = this,
        StepsPerExecution = _steps_per_execution
    };
    return PredictInternal(new DataHandler(handlerArgs), verbose);
}
/// <summary>
/// Generates output predictions for the input samples.
/// </summary>
/// <param name="x">Input samples.</param>
/// <param name="batch_size">Number of samples per batch (-1 lets the data handler decide).</param>
/// <param name="verbose">Verbosity mode.</param>
/// <param name="steps">
/// Total number of steps (batches of samples)
/// before declaring the prediction round finished (-1 means infer from the input size).
/// </param>
/// <param name="max_queue_size">Maximum size of the generator queue.</param>
/// <param name="workers">Number of workers used by the data handler.</param>
/// <param name="use_multiprocessing">Whether the data handler may use process-based parallelism.</param>
/// <returns>Concatenated prediction outputs for all batches.</returns>
public Tensors predict(Tensors x,
    int batch_size = -1,
    int verbose = 0,
    int steps = -1,
    int max_queue_size = 10,
    int workers = 1,
    bool use_multiprocessing = false)
{
    // Single-epoch handler: prediction never revisits the data.
    var data_handler = new DataHandler(new DataHandlerArgs
    {
        X = x,
        BatchSize = batch_size,
        StepsPerEpoch = steps,
        InitialEpoch = 0,
        Epochs = 1,
        MaxQueueSize = max_queue_size,
        Workers = workers,
        UseMultiprocessing = use_multiprocessing,
        Model = this,
        StepsPerExecution = _steps_per_execution
    });
    return PredictInternal(data_handler, verbose);
}
/// <summary>
/// Shared prediction loop used by both <c>predict</c> overloads.
/// Iterates the handler's single epoch, runs <c>run_predict_step</c> per batch,
/// and concatenates each output tensor along axis 0 across batches.
/// </summary>
/// <param name="data_handler">Prepared data handler (one epoch).</param>
/// <param name="verbose">Verbosity mode forwarded to the callbacks.</param>
/// <returns>Batch outputs concatenated over all steps; null if no steps ran.</returns>
Tensors PredictInternal(DataHandler data_handler, int verbose)
{
    var callbacks = new CallbackList(new CallbackParams
    {
        Model = this,
        Verbose = verbose,
        Epochs = 1,
        Steps = data_handler.Inferredsteps
    });
    Tensors batch_outputs = null;
    _predict_counter.assign(0);
    callbacks.on_predict_begin();
    foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
    {
        foreach (var step in data_handler.steps())
        {
            callbacks.on_predict_batch_begin(step);
            var tmp_batch_outputs = run_predict_step(iterator);
            if (batch_outputs == null)
            {
                // First batch establishes the output structure.
                batch_outputs = tmp_batch_outputs;
            }
            else
            {
                // Append this batch's outputs to the accumulated results, tensor by tensor.
                for (int i = 0; i < batch_outputs.Length; i++)
                    batch_outputs[i] = tf.concat(new Tensor[] { batch_outputs[i], tmp_batch_outputs[i] }, axis: 0);
            }
            var end_step = step + data_handler.StepIncrement;
            // FIX: restored the generic type arguments, which had been stripped
            // (non-generic `new Dictionary { ... }` does not compile in C#).
            callbacks.on_predict_batch_end(end_step, new Dictionary<string, Tensors> { { "outputs", batch_outputs } });
            // NOTE(review): forcing a GC per batch is normally an anti-pattern; kept here
            // to preserve existing behavior — presumably it releases eager tensors eagerly.
            // TODO confirm whether disposing intermediate tensors explicitly would suffice.
            GC.Collect();
        }
    }
    callbacks.on_predict_end();
    return batch_outputs;
}
Tensors run_predict_step(OwnedIterator iterator)
{
var data = iterator.next();
var outputs = predict_step(data);
tf_with(ops.control_dependencies(Array.Empty