Browse Source

Update Model.Evaluate.cs

Fix my bad:

Bad handling between test_function and test_step_multi_inputs_function.
pull/1092/head
Luc Bologna GitHub 2 years ago
parent
commit
340fcd3562
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 75 additions and 41 deletions
  1. +75
    -41
      src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs

+ 75
- 41
src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs View File

@@ -1,51 +1,19 @@
using Tensorflow.NumPy;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Callbacks;
using Tensorflow.Keras.Engine.DataAdapters;
using static Tensorflow.Binding;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Utils;
using Tensorflow;
using Tensorflow.Keras.Callbacks;
using Tensorflow.NumPy;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{
public partial class Model
{
protected Dictionary<string, float> evaluate(CallbackList callbacks, DataHandler data_handler, bool is_val)
{
callbacks.on_test_begin();

//Dictionary<string, float>? logs = null;
var logs = new Dictionary<string, float>();
int x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
callbacks.on_epoch_begin(epoch);
// data_handler.catch_stop_iteration();

foreach (var step in data_handler.steps())
{
callbacks.on_test_batch_begin(step);

var data = iterator.next();

logs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
tf_with(ops.control_dependencies(Array.Empty<object>()), ctl => _test_counter.assign_add(1));

var end_step = step + data_handler.StepIncrement;

if (!is_val)
callbacks.on_test_batch_end(end_step, logs);
}
}

return logs;
}

/// <summary>
/// Returns the loss value & metrics values for the model in test mode.
/// </summary>
@@ -97,7 +65,7 @@ namespace Tensorflow.Keras.Engine
Steps = data_handler.Inferredsteps
});

return evaluate(callbacks, data_handler, is_val);
return evaluate(data_handler, callbacks, is_val, test_function);
}

public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false)
@@ -117,10 +85,9 @@ namespace Tensorflow.Keras.Engine
Steps = data_handler.Inferredsteps
});

return evaluate(callbacks, data_handler, is_val);
return evaluate(data_handler, callbacks, is_val, test_step_multi_inputs_function);
}


public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is_val = false)
{
var data_handler = new DataHandler(new DataHandlerArgs
@@ -137,7 +104,74 @@ namespace Tensorflow.Keras.Engine
Steps = data_handler.Inferredsteps
});

return evaluate(callbacks, data_handler, is_val);
return evaluate(data_handler, callbacks, is_val, test_function);
}

/// <summary>
/// Internal bare implementation of evaluate function.
/// </summary>
/// <param name="data_handler">Interations handling objects</param>
/// <param name="callbacks"></param>
/// <param name="test_func">The function to be called on each batch of data.</param>
/// <param name="is_val">Whether it is validation or test.</param>
/// <returns></returns>
Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callbacks, bool is_val, Func<DataHandler, Tensor[], Dictionary<string, float>> test_func)
{
callbacks.on_test_begin();

var results = new Dictionary<string, float>();
var logs = results;
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
callbacks.on_epoch_begin(epoch);
// data_handler.catch_stop_iteration();

foreach (var step in data_handler.steps())
{
callbacks.on_test_batch_begin(step);

var data = iterator.next();

logs = test_func(data_handler, iterator.next());

tf_with(ops.control_dependencies(Array.Empty<object>()), ctl => _train_counter.assign_add(1));

var end_step = step + data_handler.StepIncrement;
if (!is_val)
callbacks.on_test_batch_end(end_step, logs);
}

if (!is_val)
callbacks.on_epoch_end(epoch, logs);
}

foreach (var log in logs)
{
results[log.Key] = log.Value;
}

return results;
}

Dictionary<string, float> test_function(DataHandler data_handler, Tensor[] data)
{
var (x, y) = data_handler.DataAdapter.Expand1d(data[0], data[1]);

var y_pred = Apply(x, training: false);
var loss = compiled_loss.Call(y, y_pred);

compiled_metrics.update_state(y, y_pred);

var outputs = metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Name, x => (float)x.Item2);
return outputs;
}

Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, Tensor[] data)
{
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
return outputs;
}
}
}
}

Loading…
Cancel
Save