diff --git a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs index 19f3df9b..ddc72aee 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs @@ -60,7 +60,7 @@ public interface IModel : ILayer bool skip_mismatch = false, object options = null); - Dictionary evaluate(NDArray x, NDArray y, + Dictionary evaluate(Tensor x, Tensor y, int batch_size = -1, int verbose = 1, int steps = -1, diff --git a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs index 185de4f4..a71f7f39 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs @@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Engine /// /// /// - public Dictionary evaluate(NDArray x, NDArray y, + public Dictionary evaluate(Tensor x, Tensor y, int batch_size = -1, int verbose = 1, int steps = -1, @@ -91,7 +91,7 @@ namespace Tensorflow.Keras.Engine return results; } - public Dictionary evaluate(IEnumerable x, NDArray y, int verbose = 1, bool is_val = false) + public Dictionary evaluate(IEnumerable x, Tensor y, int verbose = 1, bool is_val = false) { var data_handler = new DataHandler(new DataHandlerArgs { @@ -119,7 +119,7 @@ namespace Tensorflow.Keras.Engine foreach (var step in data_handler.steps()) { callbacks.on_test_batch_begin(step); - logs = test_step_multi_inputs_function(data_handler, iterator); + logs = test_function(data_handler, iterator); var end_step = step + data_handler.StepIncrement; if (is_val == false) callbacks.on_test_batch_end(end_step, logs); @@ -178,20 +178,14 @@ namespace Tensorflow.Keras.Engine } Dictionary test_function(DataHandler data_handler, OwnedIterator iterator) - { - var data = iterator.next(); - var outputs = test_step(data_handler, data[0], data[1]); - tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1)); - return outputs; - } - Dictionary test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator) { var data = iterator.next(); var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size))); - tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); + tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1)); return outputs; } + Dictionary test_step(DataHandler data_handler, Tensor x, Tensor y) { (x, y) = data_handler.DataAdapter.Expand1d(x, y); diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index bb8e18cc..17ecde98 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -266,7 +266,7 @@ namespace Tensorflow.Keras.Engine { // Because evaluate calls call_test_batch_end, this interferes with our output on the screen // so we need to pass a is_val parameter to stop on_test_batch_end - var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val:true); + var val_logs = evaluate((Tensor)validation_data.Value.Item1, validation_data.Value.Item2, is_val:true); foreach (var log in val_logs) { logs["val_" + log.Key] = log.Value;