Browse Source

Fix model.evaluate don't have output

tags/v0.100.5-BERT-load
wangdapao666 Haiping 2 years ago
parent
commit
89fe0bb59a
5 changed files with 87 additions and 16 deletions
  1. +2
    -0
      src/TensorFlowNET.Core/Keras/Engine/ICallback.cs
  2. +13
    -1
      src/TensorFlowNET.Keras/Callbacks/CallbackList.cs
  3. +18
    -5
      src/TensorFlowNET.Keras/Callbacks/History.cs
  4. +26
    -5
      src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs
  5. +28
    -5
      src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs

+ 2
- 0
src/TensorFlowNET.Core/Keras/Engine/ICallback.cs View File

@@ -12,4 +12,6 @@ public interface ICallback
void on_predict_batch_begin(long step);
void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs);
void on_predict_end();
void on_test_begin();
void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs);
}

+ 13
- 1
src/TensorFlowNET.Keras/Callbacks/CallbackList.cs View File

@@ -20,7 +20,10 @@ public class CallbackList
{
callbacks.ForEach(x => x.on_train_begin());
}

public void on_test_begin()
{
callbacks.ForEach(x => x.on_test_begin());
}
public void on_epoch_begin(int epoch)
{
callbacks.ForEach(x => x.on_epoch_begin(epoch));
@@ -60,4 +63,13 @@ public class CallbackList
{
callbacks.ForEach(x => x.on_predict_end());
}

public void on_test_batch_begin(long step)
{
callbacks.ForEach(x => x.on_train_batch_begin(step));
}
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
{
callbacks.ForEach(x => x.on_test_batch_end(end_step, logs));
}
}

+ 18
- 5
src/TensorFlowNET.Keras/Callbacks/History.cs View File

@@ -18,7 +18,11 @@ public class History : ICallback
epochs = new List<int>();
history = new Dictionary<string, List<float>>();
}

public void on_test_begin()
{
epochs = new List<int>();
history = new Dictionary<string, List<float>>();
}
public void on_epoch_begin(int epoch)
{

@@ -26,7 +30,7 @@ public class History : ICallback

public void on_train_batch_begin(long step)
{
}

public void on_train_batch_end(long end_step, Dictionary<string, float> logs)
@@ -55,16 +59,25 @@ public class History : ICallback

public void on_predict_batch_begin(long step)
{
}

public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs)
{
}

public void on_predict_end()
{

}

public void on_test_batch_begin(long step)
{

}

public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
{
}
}

+ 26
- 5
src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs View File

@@ -22,7 +22,10 @@ namespace Tensorflow.Keras.Callbacks
_called_in_fit = true;
_sw = new Stopwatch();
}

public void on_test_begin()
{
_sw = new Stopwatch();
}
public void on_epoch_begin(int epoch)
{
_reset_progbar();
@@ -44,7 +47,7 @@ namespace Tensorflow.Keras.Callbacks
var progress = "";
var length = 30.0 / _parameters.Steps;
for (int i = 0; i < Math.Floor(end_step * length - 1); i++)
progress += "=";
progress += "=";
if (progress.Length < 28)
progress += ">";
else
@@ -84,17 +87,35 @@ namespace Tensorflow.Keras.Callbacks

public void on_predict_batch_begin(long step)
{
}

public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs)
{
}

public void on_predict_end()
{

}

public void on_test_batch_begin(long step)
{
_sw.Restart();
}
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
{
_sw.Stop();
var elapse = _sw.ElapsedMilliseconds;
var results = string.Join(" - ", logs.Select(x => $"{x.Item1}: {(float)x.Item2.numpy():F6}"));

Binding.tf_output_redirect.Write($"{end_step + 1:D4}/{_parameters.Steps:D4} - {elapse}ms/step - {results}");
if (!Console.IsOutputRedirected)
{
Console.CursorLeft = 0;
}
}

}
}

+ 28
- 5
src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs View File

@@ -5,6 +5,10 @@ using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine.DataAdapters;
using static Tensorflow.Binding;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Utils;
using Tensorflow;
using Tensorflow.Keras.Callbacks;

namespace Tensorflow.Keras.Engine
{
@@ -31,6 +35,11 @@ namespace Tensorflow.Keras.Engine
bool use_multiprocessing = false,
bool return_dict = false)
{
if (x.dims[0] != y.dims[0])
{
throw new InvalidArgumentError(
$"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}");
}
var data_handler = new DataHandler(new DataHandlerArgs
{
X = x,
@@ -46,18 +55,31 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution
});

var callbacks = new CallbackList(new CallbackParams
{
Model = this,
Verbose = verbose,
Steps = data_handler.Inferredsteps
});
callbacks.on_test_begin();

foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
// callbacks.on_epoch_begin(epoch)
//callbacks.on_epoch_begin(epoch);
// data_handler.catch_stop_iteration();
IEnumerable<(string, Tensor)> results = null;
IEnumerable<(string, Tensor)> logs = null;

foreach (var step in data_handler.steps())
{
// callbacks.on_train_batch_begin(step)
results = test_function(data_handler, iterator);
callbacks.on_train_batch_begin(step);
logs = test_function(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
callbacks.on_test_batch_end(end_step, logs);
}
}
GC.Collect();
GC.WaitForPendingFinalizers();
}

public KeyValuePair<string, float>[] evaluate(IDatasetV2 x)
@@ -75,7 +97,8 @@ namespace Tensorflow.Keras.Engine
reset_metrics();
// callbacks.on_epoch_begin(epoch)
// data_handler.catch_stop_iteration();


foreach (var step in data_handler.steps())
{
// callbacks.on_train_batch_begin(step)


Loading…
Cancel
Save