Browse Source

Fix model.evaluate don't have output

tags/v0.100.5-BERT-load
wangdapao666 2 years ago
parent
commit
3ea81d23e6
5 changed files with 87 additions and 16 deletions
  1. +2
    -0
      src/TensorFlowNET.Core/Keras/Engine/ICallback.cs
  2. +13
    -1
      src/TensorFlowNET.Keras/Callbacks/CallbackList.cs
  3. +18
    -5
      src/TensorFlowNET.Keras/Callbacks/History.cs
  4. +26
    -5
      src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs
  5. +28
    -5
      src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs

+ 2
- 0
src/TensorFlowNET.Core/Keras/Engine/ICallback.cs View File

@@ -12,4 +12,6 @@ public interface ICallback
void on_predict_batch_begin(long step); void on_predict_batch_begin(long step);
void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs); void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs);
void on_predict_end(); void on_predict_end();
void on_test_begin();
void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs);
} }

+ 13
- 1
src/TensorFlowNET.Keras/Callbacks/CallbackList.cs View File

@@ -20,7 +20,10 @@ public class CallbackList
{ {
callbacks.ForEach(x => x.on_train_begin()); callbacks.ForEach(x => x.on_train_begin());
} }

public void on_test_begin()
{
callbacks.ForEach(x => x.on_test_begin());
}
public void on_epoch_begin(int epoch) public void on_epoch_begin(int epoch)
{ {
callbacks.ForEach(x => x.on_epoch_begin(epoch)); callbacks.ForEach(x => x.on_epoch_begin(epoch));
@@ -60,4 +63,13 @@ public class CallbackList
{ {
callbacks.ForEach(x => x.on_predict_end()); callbacks.ForEach(x => x.on_predict_end());
} }

public void on_test_batch_begin(long step)
{
callbacks.ForEach(x => x.on_train_batch_begin(step));
}
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
{
callbacks.ForEach(x => x.on_test_batch_end(end_step, logs));
}
} }

+ 18
- 5
src/TensorFlowNET.Keras/Callbacks/History.cs View File

@@ -18,7 +18,11 @@ public class History : ICallback
epochs = new List<int>(); epochs = new List<int>();
history = new Dictionary<string, List<float>>(); history = new Dictionary<string, List<float>>();
} }

public void on_test_begin()
{
epochs = new List<int>();
history = new Dictionary<string, List<float>>();
}
public void on_epoch_begin(int epoch) public void on_epoch_begin(int epoch)
{ {


@@ -26,7 +30,7 @@ public class History : ICallback


public void on_train_batch_begin(long step) public void on_train_batch_begin(long step)
{ {
} }


public void on_train_batch_end(long end_step, Dictionary<string, float> logs) public void on_train_batch_end(long end_step, Dictionary<string, float> logs)
@@ -55,16 +59,25 @@ public class History : ICallback


public void on_predict_batch_begin(long step) public void on_predict_batch_begin(long step)
{ {
} }


public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs)
{ {
} }


public void on_predict_end() public void on_predict_end()
{ {

}

public void on_test_batch_begin(long step)
{

}

public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
{
} }
} }

+ 26
- 5
src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs View File

@@ -22,7 +22,10 @@ namespace Tensorflow.Keras.Callbacks
_called_in_fit = true; _called_in_fit = true;
_sw = new Stopwatch(); _sw = new Stopwatch();
} }

public void on_test_begin()
{
_sw = new Stopwatch();
}
public void on_epoch_begin(int epoch) public void on_epoch_begin(int epoch)
{ {
_reset_progbar(); _reset_progbar();
@@ -44,7 +47,7 @@ namespace Tensorflow.Keras.Callbacks
var progress = ""; var progress = "";
var length = 30.0 / _parameters.Steps; var length = 30.0 / _parameters.Steps;
for (int i = 0; i < Math.Floor(end_step * length - 1); i++) for (int i = 0; i < Math.Floor(end_step * length - 1); i++)
progress += "=";
progress += "=";
if (progress.Length < 28) if (progress.Length < 28)
progress += ">"; progress += ">";
else else
@@ -84,17 +87,35 @@ namespace Tensorflow.Keras.Callbacks


public void on_predict_batch_begin(long step) public void on_predict_batch_begin(long step)
{ {
} }


public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs)
{ {
} }


public void on_predict_end() public void on_predict_end()
{ {

}

public void on_test_batch_begin(long step)
{
_sw.Restart();
} }
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
{
_sw.Stop();
var elapse = _sw.ElapsedMilliseconds;
var results = string.Join(" - ", logs.Select(x => $"{x.Item1}: {(float)x.Item2.numpy():F6}"));

Binding.tf_output_redirect.Write($"{end_step + 1:D4}/{_parameters.Steps:D4} - {elapse}ms/step - {results}");
if (!Console.IsOutputRedirected)
{
Console.CursorLeft = 0;
}
}

} }
} }

+ 28
- 5
src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs View File

@@ -5,6 +5,10 @@ using System.Linq;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine.DataAdapters; using Tensorflow.Keras.Engine.DataAdapters;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Utils;
using Tensorflow;
using Tensorflow.Keras.Callbacks;


namespace Tensorflow.Keras.Engine namespace Tensorflow.Keras.Engine
{ {
@@ -31,6 +35,11 @@ namespace Tensorflow.Keras.Engine
bool use_multiprocessing = false, bool use_multiprocessing = false,
bool return_dict = false) bool return_dict = false)
{ {
if (x.dims[0] != y.dims[0])
{
throw new InvalidArgumentError(
$"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}");
}
var data_handler = new DataHandler(new DataHandlerArgs var data_handler = new DataHandler(new DataHandlerArgs
{ {
X = x, X = x,
@@ -46,18 +55,31 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution StepsPerExecution = _steps_per_execution
}); });


var callbacks = new CallbackList(new CallbackParams
{
Model = this,
Verbose = verbose,
Steps = data_handler.Inferredsteps
});
callbacks.on_test_begin();

foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{ {
reset_metrics(); reset_metrics();
// callbacks.on_epoch_begin(epoch)
//callbacks.on_epoch_begin(epoch);
// data_handler.catch_stop_iteration(); // data_handler.catch_stop_iteration();
IEnumerable<(string, Tensor)> results = null;
IEnumerable<(string, Tensor)> logs = null;

foreach (var step in data_handler.steps()) foreach (var step in data_handler.steps())
{ {
// callbacks.on_train_batch_begin(step)
results = test_function(data_handler, iterator);
callbacks.on_train_batch_begin(step);
logs = test_function(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
callbacks.on_test_batch_end(end_step, logs);
} }
} }
GC.Collect();
GC.WaitForPendingFinalizers();
} }


public KeyValuePair<string, float>[] evaluate(IDatasetV2 x) public KeyValuePair<string, float>[] evaluate(IDatasetV2 x)
@@ -75,7 +97,8 @@ namespace Tensorflow.Keras.Engine
reset_metrics(); reset_metrics();
// callbacks.on_epoch_begin(epoch) // callbacks.on_epoch_begin(epoch)
// data_handler.catch_stop_iteration(); // data_handler.catch_stop_iteration();


foreach (var step in data_handler.steps()) foreach (var step in data_handler.steps())
{ {
// callbacks.on_train_batch_begin(step) // callbacks.on_train_batch_begin(step)


Loading…
Cancel
Save