|
|
@@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Engine |
|
|
|
/// <param name="use_multiprocessing"></param> |
|
|
|
/// <param name="return_dict"></param> |
|
|
|
/// <param name="is_val"></param> |
|
|
|
public Dictionary<string, float> evaluate(Tensor x, Tensor y, |
|
|
|
public Dictionary<string, float> evaluate(NDArray x, NDArray y, |
|
|
|
int batch_size = -1, |
|
|
|
int verbose = 1, |
|
|
|
int steps = -1, |
|
|
@@ -115,62 +115,53 @@ namespace Tensorflow.Keras.Engine |
|
|
|
/// <param name="test_func">The function to be called on each batch of data.</param> |
|
|
|
/// <param name="is_val">Whether it is validation or test.</param> |
|
|
|
/// <returns></returns> |
|
|
|
Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callbacks, bool is_val, Func<DataHandler, Tensor[], Dictionary<string, float>> test_func) |
|
|
|
Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callbacks, bool is_val, Func<DataHandler, OwnedIterator, Dictionary<string, float>> test_func) |
|
|
|
{ |
|
|
|
callbacks.on_test_begin(); |
|
|
|
|
|
|
|
var results = new Dictionary<string, float>(); |
|
|
|
var logs = results; |
|
|
|
var logs = new Dictionary<string, float>(); |
|
|
|
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) |
|
|
|
{ |
|
|
|
reset_metrics(); |
|
|
|
callbacks.on_epoch_begin(epoch); |
|
|
|
// data_handler.catch_stop_iteration(); |
|
|
|
|
|
|
|
foreach (var step in data_handler.steps()) |
|
|
|
{ |
|
|
|
callbacks.on_test_batch_begin(step); |
|
|
|
|
|
|
|
logs = test_func(data_handler, iterator.next()); |
|
|
|
|
|
|
|
tf_with(ops.control_dependencies(Array.Empty<object>()), ctl => _train_counter.assign_add(1)); |
|
|
|
|
|
|
|
logs = test_func(data_handler, iterator); |
|
|
|
var end_step = step + data_handler.StepIncrement; |
|
|
|
if (!is_val) |
|
|
|
callbacks.on_test_batch_end(end_step, logs); |
|
|
|
} |
|
|
|
|
|
|
|
if (!is_val) |
|
|
|
callbacks.on_epoch_end(epoch, logs); |
|
|
|
} |
|
|
|
|
|
|
|
foreach (var log in logs) |
|
|
|
{ |
|
|
|
results[log.Key] = log.Value; |
|
|
|
} |
|
|
|
|
|
|
|
callbacks.on_test_end(logs); |
|
|
|
var results = new Dictionary<string, float>(logs); |
|
|
|
return results; |
|
|
|
} |
|
|
|
|
|
|
|
Dictionary<string, float> test_function(DataHandler data_handler, Tensor[] data) |
|
|
|
Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator) |
|
|
|
{ |
|
|
|
var (x, y) = data_handler.DataAdapter.Expand1d(data[0], data[1]); |
|
|
|
|
|
|
|
var y_pred = Apply(x, training: false); |
|
|
|
var loss = compiled_loss.Call(y, y_pred); |
|
|
|
|
|
|
|
compiled_metrics.update_state(y, y_pred); |
|
|
|
|
|
|
|
var outputs = metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Name, x => (float)x.Item2); |
|
|
|
var data = iterator.next(); |
|
|
|
var outputs = test_step(data_handler, data[0], data[1]); |
|
|
|
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1)); |
|
|
|
return outputs; |
|
|
|
} |
|
|
|
|
|
|
|
Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, Tensor[] data) |
|
|
|
Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator) |
|
|
|
{ |
|
|
|
var data = iterator.next(); |
|
|
|
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; |
|
|
|
var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())); |
|
|
|
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); |
|
|
|
var outputs = test_step(data_handler, data.Take(x_size).ToArray(), data.Skip(x_size).ToArray()); |
|
|
|
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1)); |
|
|
|
return outputs; |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y) |
|
|
|
{ |
|
|
|
(x, y) = data_handler.DataAdapter.Expand1d(x, y); |
|
|
|
var y_pred = Apply(x, training: false); |
|
|
|
var loss = compiled_loss.Call(y, y_pred); |
|
|
|
compiled_metrics.update_state(y, y_pred); |
|
|
|
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2); |
|
|
|
} |
|
|
|
} |
|
|
|
} |