Browse Source

Refactor: Change Model evaluate

IModel.Dictionary<string, float> evaluate(NDArray, NDArray, ...) is now IModel.Dictionary<string, float> evaluate(Tensor, Tensor, ...)
Merge Model.Evaluate.test_step_multi_inputs_function(...) and Model.Evaluate.test_function(...)

Note: An internal function need to add an explicit cast in Tensor
tags/v0.110.0-LSTM-Model
Luc BOLOGNA Haiping 2 years ago
parent
commit
02cb239c5f
3 changed files with 7 additions and 13 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Keras/Engine/IModel.cs
  2. +5
    -11
      src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
  3. +1
    -1
      src/TensorFlowNET.Keras/Engine/Model.Fit.cs

+ 1
- 1
src/TensorFlowNET.Core/Keras/Engine/IModel.cs View File

@@ -60,7 +60,7 @@ public interface IModel : ILayer
bool skip_mismatch = false,
object options = null);

Dictionary<string, float> evaluate(NDArray x, NDArray y,
Dictionary<string, float> evaluate(Tensor x, Tensor y,
int batch_size = -1,
int verbose = 1,
int steps = -1,


+ 5
- 11
src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs View File

@@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Engine
/// <param name="use_multiprocessing"></param>
/// <param name="return_dict"></param>
/// <param name="is_val"></param>
public Dictionary<string, float> evaluate(NDArray x, NDArray y,
public Dictionary<string, float> evaluate(Tensor x, Tensor y,
int batch_size = -1,
int verbose = 1,
int steps = -1,
@@ -91,7 +91,7 @@ namespace Tensorflow.Keras.Engine
return results;
}

public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, NDArray y, int verbose = 1, bool is_val = false)
public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false)
{
var data_handler = new DataHandler(new DataHandlerArgs
{
@@ -119,7 +119,7 @@ namespace Tensorflow.Keras.Engine
foreach (var step in data_handler.steps())
{
callbacks.on_test_batch_begin(step);
logs = test_step_multi_inputs_function(data_handler, iterator);
logs = test_function(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
if (is_val == false)
callbacks.on_test_batch_end(end_step, logs);
@@ -178,20 +178,14 @@ namespace Tensorflow.Keras.Engine
}

Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
{
var data = iterator.next();
var outputs = test_step(data_handler, data[0], data[1]);
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
return outputs;
}
Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator)
{
var data = iterator.next();
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
return outputs;
}

Dictionary<string, float> test_step(DataHandler data_handler, Tensor x, Tensor y)
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -266,7 +266,7 @@ namespace Tensorflow.Keras.Engine
{
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen
// so we need to pass a is_val parameter to stop on_test_batch_end
var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val:true);
var val_logs = evaluate((Tensor)validation_data.Value.Item1, validation_data.Value.Item2, is_val:true);
foreach (var log in val_logs)
{
logs["val_" + log.Key] = log.Value;


Loading…
Cancel
Save