@@ -12,4 +12,6 @@ public interface ICallback | |||||
void on_predict_batch_begin(long step); | void on_predict_batch_begin(long step); | ||||
void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs); | void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs); | ||||
void on_predict_end(); | void on_predict_end(); | ||||
void on_test_begin(); | |||||
void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs); | |||||
} | } |
@@ -20,7 +20,10 @@ public class CallbackList | |||||
{ | { | ||||
callbacks.ForEach(x => x.on_train_begin()); | callbacks.ForEach(x => x.on_train_begin()); | ||||
} | } | ||||
public void on_test_begin() | |||||
{ | |||||
callbacks.ForEach(x => x.on_test_begin()); | |||||
} | |||||
public void on_epoch_begin(int epoch) | public void on_epoch_begin(int epoch) | ||||
{ | { | ||||
callbacks.ForEach(x => x.on_epoch_begin(epoch)); | callbacks.ForEach(x => x.on_epoch_begin(epoch)); | ||||
@@ -60,4 +63,13 @@ public class CallbackList | |||||
{ | { | ||||
callbacks.ForEach(x => x.on_predict_end()); | callbacks.ForEach(x => x.on_predict_end()); | ||||
} | } | ||||
public void on_test_batch_begin(long step) | |||||
{ | |||||
callbacks.ForEach(x => x.on_train_batch_begin(step)); | |||||
} | |||||
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) | |||||
{ | |||||
callbacks.ForEach(x => x.on_test_batch_end(end_step, logs)); | |||||
} | |||||
} | } |
@@ -18,7 +18,11 @@ public class History : ICallback | |||||
epochs = new List<int>(); | epochs = new List<int>(); | ||||
history = new Dictionary<string, List<float>>(); | history = new Dictionary<string, List<float>>(); | ||||
} | } | ||||
public void on_test_begin() | |||||
{ | |||||
epochs = new List<int>(); | |||||
history = new Dictionary<string, List<float>>(); | |||||
} | |||||
public void on_epoch_begin(int epoch) | public void on_epoch_begin(int epoch) | ||||
{ | { | ||||
@@ -26,7 +30,7 @@ public class History : ICallback | |||||
public void on_train_batch_begin(long step) | public void on_train_batch_begin(long step) | ||||
{ | { | ||||
} | } | ||||
public void on_train_batch_end(long end_step, Dictionary<string, float> logs) | public void on_train_batch_end(long end_step, Dictionary<string, float> logs) | ||||
@@ -55,16 +59,25 @@ public class History : ICallback | |||||
public void on_predict_batch_begin(long step) | public void on_predict_batch_begin(long step) | ||||
{ | { | ||||
} | } | ||||
public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) | public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) | ||||
{ | { | ||||
} | } | ||||
public void on_predict_end() | public void on_predict_end() | ||||
{ | { | ||||
} | |||||
public void on_test_batch_begin(long step) | |||||
{ | |||||
} | |||||
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) | |||||
{ | |||||
} | } | ||||
} | } |
@@ -22,7 +22,10 @@ namespace Tensorflow.Keras.Callbacks | |||||
_called_in_fit = true; | _called_in_fit = true; | ||||
_sw = new Stopwatch(); | _sw = new Stopwatch(); | ||||
} | } | ||||
public void on_test_begin() | |||||
{ | |||||
_sw = new Stopwatch(); | |||||
} | |||||
public void on_epoch_begin(int epoch) | public void on_epoch_begin(int epoch) | ||||
{ | { | ||||
_reset_progbar(); | _reset_progbar(); | ||||
@@ -44,7 +47,7 @@ namespace Tensorflow.Keras.Callbacks | |||||
var progress = ""; | var progress = ""; | ||||
var length = 30.0 / _parameters.Steps; | var length = 30.0 / _parameters.Steps; | ||||
for (int i = 0; i < Math.Floor(end_step * length - 1); i++) | for (int i = 0; i < Math.Floor(end_step * length - 1); i++) | ||||
progress += "="; | |||||
progress += "="; | |||||
if (progress.Length < 28) | if (progress.Length < 28) | ||||
progress += ">"; | progress += ">"; | ||||
else | else | ||||
@@ -84,17 +87,35 @@ namespace Tensorflow.Keras.Callbacks | |||||
public void on_predict_batch_begin(long step) | public void on_predict_batch_begin(long step) | ||||
{ | { | ||||
} | } | ||||
public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) | public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) | ||||
{ | { | ||||
} | } | ||||
public void on_predict_end() | public void on_predict_end() | ||||
{ | { | ||||
} | |||||
public void on_test_batch_begin(long step) | |||||
{ | |||||
_sw.Restart(); | |||||
} | } | ||||
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) | |||||
{ | |||||
_sw.Stop(); | |||||
var elapse = _sw.ElapsedMilliseconds; | |||||
var results = string.Join(" - ", logs.Select(x => $"{x.Item1}: {(float)x.Item2.numpy():F6}")); | |||||
Binding.tf_output_redirect.Write($"{end_step + 1:D4}/{_parameters.Steps:D4} - {elapse}ms/step - {results}"); | |||||
if (!Console.IsOutputRedirected) | |||||
{ | |||||
Console.CursorLeft = 0; | |||||
} | |||||
} | |||||
} | } | ||||
} | } |
@@ -5,6 +5,10 @@ using System.Linq; | |||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Engine.DataAdapters; | using Tensorflow.Keras.Engine.DataAdapters; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using Tensorflow.Keras.Layers; | |||||
using Tensorflow.Keras.Utils; | |||||
using Tensorflow; | |||||
using Tensorflow.Keras.Callbacks; | |||||
namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
{ | { | ||||
@@ -31,6 +35,11 @@ namespace Tensorflow.Keras.Engine | |||||
bool use_multiprocessing = false, | bool use_multiprocessing = false, | ||||
bool return_dict = false) | bool return_dict = false) | ||||
{ | { | ||||
if (x.dims[0] != y.dims[0]) | |||||
{ | |||||
throw new InvalidArgumentError( | |||||
$"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}"); | |||||
} | |||||
var data_handler = new DataHandler(new DataHandlerArgs | var data_handler = new DataHandler(new DataHandlerArgs | ||||
{ | { | ||||
X = x, | X = x, | ||||
@@ -46,18 +55,31 @@ namespace Tensorflow.Keras.Engine | |||||
StepsPerExecution = _steps_per_execution | StepsPerExecution = _steps_per_execution | ||||
}); | }); | ||||
var callbacks = new CallbackList(new CallbackParams | |||||
{ | |||||
Model = this, | |||||
Verbose = verbose, | |||||
Steps = data_handler.Inferredsteps | |||||
}); | |||||
callbacks.on_test_begin(); | |||||
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | ||||
{ | { | ||||
reset_metrics(); | reset_metrics(); | ||||
// callbacks.on_epoch_begin(epoch) | |||||
//callbacks.on_epoch_begin(epoch); | |||||
// data_handler.catch_stop_iteration(); | // data_handler.catch_stop_iteration(); | ||||
IEnumerable<(string, Tensor)> results = null; | |||||
IEnumerable<(string, Tensor)> logs = null; | |||||
foreach (var step in data_handler.steps()) | foreach (var step in data_handler.steps()) | ||||
{ | { | ||||
// callbacks.on_train_batch_begin(step) | |||||
results = test_function(data_handler, iterator); | |||||
callbacks.on_train_batch_begin(step); | |||||
logs = test_function(data_handler, iterator); | |||||
var end_step = step + data_handler.StepIncrement; | |||||
callbacks.on_test_batch_end(end_step, logs); | |||||
} | } | ||||
} | } | ||||
GC.Collect(); | |||||
GC.WaitForPendingFinalizers(); | |||||
} | } | ||||
public KeyValuePair<string, float>[] evaluate(IDatasetV2 x) | public KeyValuePair<string, float>[] evaluate(IDatasetV2 x) | ||||
@@ -75,7 +97,8 @@ namespace Tensorflow.Keras.Engine | |||||
reset_metrics(); | reset_metrics(); | ||||
// callbacks.on_epoch_begin(epoch) | // callbacks.on_epoch_begin(epoch) | ||||
// data_handler.catch_stop_iteration(); | // data_handler.catch_stop_iteration(); | ||||
foreach (var step in data_handler.steps()) | foreach (var step in data_handler.steps()) | ||||
{ | { | ||||
// callbacks.on_train_batch_begin(step) | // callbacks.on_train_batch_begin(step) | ||||