Browse Source

Merge pull request #1123 from Wanglongzhi2001/master

fix: fix the bug of repeated progress bar in Model.fit()
tags/v0.110.4-Transformer-Model
Haiping GitHub 2 years ago
parent
commit
86b235f9ff
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 45 additions and 35 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/Keras/Engine/ICallback.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Keras/Engine/IModel.cs
  3. +5
    -0
      src/TensorFlowNET.Keras/Callbacks/CallbackList.cs
  4. +4
    -0
      src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs
  5. +4
    -0
      src/TensorFlowNET.Keras/Callbacks/History.cs
  6. +3
    -0
      src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs
  7. +24
    -33
      src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
  8. +1
    -1
      src/TensorFlowNET.Keras/Engine/Model.Fit.cs

+ 3
- 0
src/TensorFlowNET.Core/Keras/Engine/ICallback.cs View File

@@ -14,6 +14,9 @@ public interface ICallback
void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs);
void on_predict_end();
void on_test_begin();
void on_test_end(Dictionary<string, float> logs);
void on_test_batch_begin(long step);
void on_test_batch_end(long end_step, Dictionary<string, float> logs);


}

+ 1
- 1
src/TensorFlowNET.Core/Keras/Engine/IModel.cs View File

@@ -60,7 +60,7 @@ public interface IModel : ILayer
bool skip_mismatch = false,
object options = null);

Dictionary<string, float> evaluate(Tensor x, Tensor y,
Dictionary<string, float> evaluate(NDArray x, NDArray y,
int batch_size = -1,
int verbose = 1,
int steps = -1,


+ 5
- 0
src/TensorFlowNET.Keras/Callbacks/CallbackList.cs View File

@@ -73,4 +73,9 @@ public class CallbackList
{
callbacks.ForEach(x => x.on_test_batch_end(end_step, logs));
}

public void on_test_end(Dictionary<string, float> logs)
{
callbacks.ForEach(x => x.on_test_end(logs));
}
}

+ 4
- 0
src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs View File

@@ -150,4 +150,8 @@ public class EarlyStopping: ICallback
return less_op;
}
}

public void on_test_end(Dictionary<string, float> logs)
{
}
}

+ 4
- 0
src/TensorFlowNET.Keras/Callbacks/History.cs View File

@@ -81,4 +81,8 @@ public class History : ICallback
public void on_test_batch_end(long end_step, Dictionary<string, float> logs)
{
}

public void on_test_end(Dictionary<string, float> logs)
{
}
}

+ 3
- 0
src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs View File

@@ -118,5 +118,8 @@ namespace Tensorflow.Keras.Callbacks
}
}

public void on_test_end(Dictionary<string, float> logs)
{
}
}
}

+ 24
- 33
src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs View File

@@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Engine
/// <param name="use_multiprocessing"></param>
/// <param name="return_dict"></param>
/// <param name="is_val"></param>
public Dictionary<string, float> evaluate(Tensor x, Tensor y,
public Dictionary<string, float> evaluate(NDArray x, NDArray y,
int batch_size = -1,
int verbose = 1,
int steps = -1,
@@ -115,62 +115,53 @@ namespace Tensorflow.Keras.Engine
/// <param name="test_func">The function to be called on each batch of data.</param>
/// <param name="is_val">Whether it is validation or test.</param>
/// <returns></returns>
Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callbacks, bool is_val, Func<DataHandler, Tensor[], Dictionary<string, float>> test_func)
Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callbacks, bool is_val, Func<DataHandler, OwnedIterator, Dictionary<string, float>> test_func)
{
callbacks.on_test_begin();

var results = new Dictionary<string, float>();
var logs = results;
var logs = new Dictionary<string, float>();
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
callbacks.on_epoch_begin(epoch);
// data_handler.catch_stop_iteration();

foreach (var step in data_handler.steps())
{
callbacks.on_test_batch_begin(step);

logs = test_func(data_handler, iterator.next());

tf_with(ops.control_dependencies(Array.Empty<object>()), ctl => _train_counter.assign_add(1));

logs = test_func(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
if (!is_val)
callbacks.on_test_batch_end(end_step, logs);
}

if (!is_val)
callbacks.on_epoch_end(epoch, logs);
}

foreach (var log in logs)
{
results[log.Key] = log.Value;
}

callbacks.on_test_end(logs);
var results = new Dictionary<string, float>(logs);
return results;
}

Dictionary<string, float> test_function(DataHandler data_handler, Tensor[] data)
Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
{
var (x, y) = data_handler.DataAdapter.Expand1d(data[0], data[1]);

var y_pred = Apply(x, training: false);
var loss = compiled_loss.Call(y, y_pred);

compiled_metrics.update_state(y, y_pred);

var outputs = metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Name, x => (float)x.Item2);
var data = iterator.next();
var outputs = test_step(data_handler, data[0], data[1]);
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
return outputs;
}

Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, Tensor[] data)
Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator)
{
var data = iterator.next();
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray()));
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
var outputs = test_step(data_handler, data.Take(x_size).ToArray(), data.Skip(x_size).ToArray());
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
return outputs;
}


Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
var y_pred = Apply(x, training: false);
var loss = compiled_loss.Call(y, y_pred);
compiled_metrics.update_state(y, y_pred);
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
}
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -266,7 +266,7 @@ namespace Tensorflow.Keras.Engine
{
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen
// so we need to pass a is_val parameter to stop on_test_batch_end
var val_logs = evaluate((Tensor)validation_data.Value.Item1, validation_data.Value.Item2, is_val:true);
var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val:true);
foreach (var log in val_logs)
{
logs["val_" + log.Key] = log.Value;


Loading…
Cancel
Save