Browse Source

Refactor: Model.Evaluate.cs

tags/v0.110.0-LSTM-Model
Luc BOLOGNA Haiping 2 years ago
parent
commit
f7208c9494
1 changed files with 36 additions and 93 deletions
  1. +36
    -93
      src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs

+ 36
- 93
src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs View File

@@ -14,6 +14,38 @@ namespace Tensorflow.Keras.Engine
{ {
public partial class Model public partial class Model
{ {
protected Dictionary<string, float> evaluate(CallbackList callbacks, DataHandler data_handler, bool is_val)
{
callbacks.on_test_begin();

//Dictionary<string, float>? logs = null;
var logs = new Dictionary<string, float>();
int x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
callbacks.on_epoch_begin(epoch);
// data_handler.catch_stop_iteration();

foreach (var step in data_handler.steps())
{
callbacks.on_test_batch_begin(step);

var data = iterator.next();

logs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
tf_with(ops.control_dependencies(Array.Empty<object>()), ctl => _test_counter.assign_add(1));

var end_step = step + data_handler.StepIncrement;

if (!is_val)
callbacks.on_test_batch_end(end_step, logs);
}
}

return logs;
}

/// <summary> /// <summary>
/// Returns the loss value & metrics values for the model in test mode. /// Returns the loss value & metrics values for the model in test mode.
/// </summary> /// </summary>
@@ -64,31 +96,8 @@ namespace Tensorflow.Keras.Engine
Verbose = verbose, Verbose = verbose,
Steps = data_handler.Inferredsteps Steps = data_handler.Inferredsteps
}); });
callbacks.on_test_begin();

//Dictionary<string, float>? logs = null;
var logs = new Dictionary<string, float>();
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
// data_handler.catch_stop_iteration();


foreach (var step in data_handler.steps())
{
callbacks.on_test_batch_begin(step);
logs = test_function(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
if (is_val == false)
callbacks.on_test_batch_end(end_step, logs);
}
}

var results = new Dictionary<string, float>();
foreach (var log in logs)
{
results[log.Key] = log.Value;
}
return results;
return evaluate(callbacks, data_handler, is_val);
} }


public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false) public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false)
@@ -107,31 +116,8 @@ namespace Tensorflow.Keras.Engine
Verbose = verbose, Verbose = verbose,
Steps = data_handler.Inferredsteps Steps = data_handler.Inferredsteps
}); });
callbacks.on_test_begin();


Dictionary<string, float> logs = null;
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
callbacks.on_epoch_begin(epoch);
// data_handler.catch_stop_iteration();

foreach (var step in data_handler.steps())
{
callbacks.on_test_batch_begin(step);
logs = test_function(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
if (is_val == false)
callbacks.on_test_batch_end(end_step, logs);
}
}

var results = new Dictionary<string, float>();
foreach (var log in logs)
{
results[log.Key] = log.Value;
}
return results;
return evaluate(callbacks, data_handler, is_val);
} }




@@ -150,51 +136,8 @@ namespace Tensorflow.Keras.Engine
Verbose = verbose, Verbose = verbose,
Steps = data_handler.Inferredsteps Steps = data_handler.Inferredsteps
}); });
callbacks.on_test_begin();

Dictionary<string, float> logs = null;
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
callbacks.on_epoch_begin(epoch);
// data_handler.catch_stop_iteration();

foreach (var step in data_handler.steps())
{
callbacks.on_test_batch_begin(step);
logs = test_function(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
if (is_val == false)
callbacks.on_test_batch_end(end_step, logs);
}
}

var results = new Dictionary<string, float>();
foreach (var log in logs)
{
results[log.Key] = log.Value;
}
return results;
}

Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
{
var data = iterator.next();
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
return outputs;
}

Dictionary<string, float> test_step(DataHandler data_handler, Tensor x, Tensor y)
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
var y_pred = Apply(x, training: false);
var loss = compiled_loss.Call(y, y_pred);

compiled_metrics.update_state(y, y_pred);


return metrics.Select(x => (x.Name, x.result())).ToDictionary(x=>x.Item1, x=>(float)x.Item2);
return evaluate(callbacks, data_handler, is_val);
} }
} }
}
}

Loading…
Cancel
Save