@@ -293,7 +293,8 @@ namespace Tensorflow | |||||
// c_api.TF_CloseSession(handle, tf.Status.Handle); | // c_api.TF_CloseSession(handle, tf.Status.Handle); | ||||
if (tf.Status == null || tf.Status.Handle.IsInvalid) | if (tf.Status == null || tf.Status.Handle.IsInvalid) | ||||
{ | { | ||||
c_api.TF_DeleteSession(handle, c_api.TF_NewStatus()); | |||||
using var status = new Status(); | |||||
c_api.TF_DeleteSession(handle, status.Handle); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -39,5 +39,25 @@ namespace Tensorflow.Keras.Callbacks | |||||
{ | { | ||||
callbacks.ForEach(x => x.on_epoch_end(epoch, epoch_logs)); | callbacks.ForEach(x => x.on_epoch_end(epoch, epoch_logs)); | ||||
} | } | ||||
public void on_predict_begin() | |||||
{ | |||||
callbacks.ForEach(x => x.on_predict_begin()); | |||||
} | |||||
public void on_predict_batch_begin(long step) | |||||
{ | |||||
callbacks.ForEach(x => x.on_predict_batch_begin(step)); | |||||
} | |||||
public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) | |||||
{ | |||||
callbacks.ForEach(x => x.on_predict_batch_end(end_step, logs)); | |||||
} | |||||
public void on_predict_end() | |||||
{ | |||||
callbacks.ForEach(x => x.on_predict_end()); | |||||
} | |||||
} | } | ||||
} | } |
@@ -48,5 +48,26 @@ namespace Tensorflow.Keras.Callbacks | |||||
history[log.Key].Add((float)log.Value); | history[log.Key].Add((float)log.Value); | ||||
} | } | ||||
} | } | ||||
public void on_predict_begin() | |||||
{ | |||||
epochs = new List<int>(); | |||||
history = new Dictionary<string, List<float>>(); | |||||
} | |||||
public void on_predict_batch_begin(long step) | |||||
{ | |||||
} | |||||
public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) | |||||
{ | |||||
} | |||||
public void on_predict_end() | |||||
{ | |||||
} | |||||
} | } | ||||
} | } |
@@ -11,5 +11,9 @@ namespace Tensorflow.Keras.Callbacks | |||||
void on_train_batch_begin(long step); | void on_train_batch_begin(long step); | ||||
void on_train_batch_end(long end_step, Dictionary<string, float> logs); | void on_train_batch_end(long end_step, Dictionary<string, float> logs); | ||||
void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs); | void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs); | ||||
void on_predict_begin(); | |||||
void on_predict_batch_begin(long step); | |||||
void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs); | |||||
void on_predict_end(); | |||||
} | } | ||||
} | } |
@@ -1,5 +1,4 @@ | |||||
using PureHDF; | |||||
using System; | |||||
using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Diagnostics; | using System.Diagnostics; | ||||
using System.Linq; | using System.Linq; | ||||
@@ -77,5 +76,26 @@ namespace Tensorflow.Keras.Callbacks | |||||
{ | { | ||||
} | } | ||||
public void on_predict_begin() | |||||
{ | |||||
_reset_progbar(); | |||||
_maybe_init_progbar(); | |||||
} | |||||
public void on_predict_batch_begin(long step) | |||||
{ | |||||
} | |||||
public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) | |||||
{ | |||||
} | |||||
public void on_predict_end() | |||||
{ | |||||
} | |||||
} | } | ||||
} | } |
@@ -5,11 +5,70 @@ using System.Linq; | |||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Engine.DataAdapters; | using Tensorflow.Keras.Engine.DataAdapters; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using Tensorflow.Keras.Callbacks; | |||||
namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
{ | { | ||||
public partial class Model | public partial class Model | ||||
{ | { | ||||
public Tensors predict(IDatasetV2 dataset, | |||||
int batch_size = -1, | |||||
int verbose = 0, | |||||
int steps = -1, | |||||
int max_queue_size = 10, | |||||
int workers = 1, | |||||
bool use_multiprocessing = false) | |||||
{ | |||||
var data_handler = new DataHandler(new DataHandlerArgs | |||||
{ | |||||
Dataset = dataset, | |||||
BatchSize = batch_size, | |||||
StepsPerEpoch = steps, | |||||
InitialEpoch = 0, | |||||
Epochs = 1, | |||||
MaxQueueSize = max_queue_size, | |||||
Workers = workers, | |||||
UseMultiprocessing = use_multiprocessing, | |||||
Model = this, | |||||
StepsPerExecution = _steps_per_execution | |||||
}); | |||||
var callbacks = new CallbackList(new CallbackParams | |||||
{ | |||||
Model = this, | |||||
Verbose = verbose, | |||||
Epochs = 1, | |||||
Steps = data_handler.Inferredsteps | |||||
}); | |||||
Tensor batch_outputs = null; | |||||
_predict_counter.assign(0); | |||||
callbacks.on_predict_begin(); | |||||
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||||
{ | |||||
foreach (var step in data_handler.steps()) | |||||
{ | |||||
callbacks.on_predict_batch_begin(step); | |||||
var tmp_batch_outputs = run_predict_step(iterator); | |||||
if (batch_outputs == null) | |||||
{ | |||||
batch_outputs = tmp_batch_outputs[0]; | |||||
} | |||||
else | |||||
{ | |||||
batch_outputs = tf.concat(new Tensor[] { batch_outputs, tmp_batch_outputs[0] }, axis: 0); | |||||
} | |||||
var end_step = step + data_handler.StepIncrement; | |||||
callbacks.on_predict_batch_end(end_step, new Dictionary<string, Tensors> { { "outputs", batch_outputs } }); | |||||
} | |||||
GC.Collect(); | |||||
} | |||||
callbacks.on_predict_end(); | |||||
return batch_outputs; | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Generates output predictions for the input samples. | /// Generates output predictions for the input samples. | ||||
/// </summary> | /// </summary> | ||||