|
|
@@ -5,11 +5,70 @@ using System.Linq; |
|
|
|
using Tensorflow.Keras.ArgsDefinition; |
|
|
|
using Tensorflow.Keras.Engine.DataAdapters; |
|
|
|
using static Tensorflow.Binding; |
|
|
|
using Tensorflow.Keras.Callbacks; |
|
|
|
|
|
|
|
namespace Tensorflow.Keras.Engine |
|
|
|
{ |
|
|
|
public partial class Model |
|
|
|
{ |
|
|
|
public Tensors predict(IDatasetV2 dataset, |
|
|
|
int batch_size = -1, |
|
|
|
int verbose = 0, |
|
|
|
int steps = -1, |
|
|
|
int max_queue_size = 10, |
|
|
|
int workers = 1, |
|
|
|
bool use_multiprocessing = false) |
|
|
|
{ |
|
|
|
var data_handler = new DataHandler(new DataHandlerArgs |
|
|
|
{ |
|
|
|
Dataset = dataset, |
|
|
|
BatchSize = batch_size, |
|
|
|
StepsPerEpoch = steps, |
|
|
|
InitialEpoch = 0, |
|
|
|
Epochs = 1, |
|
|
|
MaxQueueSize = max_queue_size, |
|
|
|
Workers = workers, |
|
|
|
UseMultiprocessing = use_multiprocessing, |
|
|
|
Model = this, |
|
|
|
StepsPerExecution = _steps_per_execution |
|
|
|
}); |
|
|
|
|
|
|
|
var callbacks = new CallbackList(new CallbackParams |
|
|
|
{ |
|
|
|
Model = this, |
|
|
|
Verbose = verbose, |
|
|
|
Epochs = 1, |
|
|
|
Steps = data_handler.Inferredsteps |
|
|
|
}); |
|
|
|
|
|
|
|
Tensor batch_outputs = null; |
|
|
|
_predict_counter.assign(0); |
|
|
|
callbacks.on_predict_begin(); |
|
|
|
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) |
|
|
|
{ |
|
|
|
foreach (var step in data_handler.steps()) |
|
|
|
{ |
|
|
|
callbacks.on_predict_batch_begin(step); |
|
|
|
var tmp_batch_outputs = run_predict_step(iterator); |
|
|
|
if (batch_outputs == null) |
|
|
|
{ |
|
|
|
batch_outputs = tmp_batch_outputs[0]; |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
batch_outputs = tf.concat(new Tensor[] { batch_outputs, tmp_batch_outputs[0] }, axis: 0); |
|
|
|
} |
|
|
|
|
|
|
|
var end_step = step + data_handler.StepIncrement; |
|
|
|
callbacks.on_predict_batch_end(end_step, new Dictionary<string, Tensors> { { "outputs", batch_outputs } }); |
|
|
|
} |
|
|
|
GC.Collect(); |
|
|
|
} |
|
|
|
|
|
|
|
callbacks.on_predict_end(); |
|
|
|
return batch_outputs; |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Generates output predictions for the input samples. |
|
|
|
/// </summary> |
|
|
|