Browse Source

Merge branch 'master' of https://github.com/SciSharp/TensorFlow.NET into alnovi/optimizer_tests

pull/1184/head
Alexander 1 year ago
parent
commit
b1972a84f4
4 changed files with 58 additions and 17 deletions
  1. +21
    -9
      src/TensorFlowNET.Core/Util/Data.cs
  2. +11
    -3
      src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs
  3. +7
    -1
      src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
  4. +19
    -4
      src/TensorFlowNET.Keras/Engine/Model.Fit.cs

+ 21
- 9
src/TensorFlowNET.Core/Util/Data.cs View File

@@ -1,4 +1,5 @@
using Tensorflow.NumPy;
using OneOf;
using Tensorflow.NumPy;


namespace Tensorflow.Util namespace Tensorflow.Util
{ {
@@ -8,10 +9,10 @@ namespace Tensorflow.Util
/// </summary> /// </summary>
public class ValidationDataPack public class ValidationDataPack
{ {
public NDArray val_x;
public NDArray val_y;
public NDArray val_sample_weight = null;
internal OneOf<NDArray, NDArray[]> val_x;
internal NDArray val_y;
internal NDArray val_sample_weight = null;
public bool val_x_is_array = false;
public ValidationDataPack((NDArray, NDArray) validation_data) public ValidationDataPack((NDArray, NDArray) validation_data)
{ {
this.val_x = validation_data.Item1; this.val_x = validation_data.Item1;
@@ -27,15 +28,17 @@ namespace Tensorflow.Util


public ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data) public ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data)
{ {
this.val_x = validation_data.Item1.ToArray()[0];
this.val_x = validation_data.Item1.ToArray();
this.val_y = validation_data.Item2; this.val_y = validation_data.Item2;
val_x_is_array = true;
} }


public ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data) public ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data)
{ {
this.val_x = validation_data.Item1.ToArray()[0];
this.val_x = validation_data.Item1.ToArray();
this.val_y = validation_data.Item2; this.val_y = validation_data.Item2;
this.val_sample_weight = validation_data.Item3; this.val_sample_weight = validation_data.Item3;
val_x_is_array = true;
} }


public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data) public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data)
@@ -52,15 +55,24 @@ namespace Tensorflow.Util


public void Deconstruct(out NDArray val_x, out NDArray val_y) public void Deconstruct(out NDArray val_x, out NDArray val_y)
{ {
val_x = this.val_x;
val_x = this.val_x.AsT0;
val_y = this.val_y; val_y = this.val_y;
} }


public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight) public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight)
{ {
val_x = this.val_x;
val_x = this.val_x.AsT0;
val_y = this.val_y;
val_sample_weight = this.val_sample_weight;
}

// add a unuse parameter to make it different from Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight)
public void Deconstruct(out NDArray[] val_x_array, out NDArray val_y, out NDArray val_sample_weight, out NDArray unuse)
{
val_x_array = this.val_x.AsT1;
val_y = this.val_y; val_y = this.val_y;
val_sample_weight = this.val_sample_weight; val_sample_weight = this.val_sample_weight;
unuse = null;
} }
} }
} }

+ 11
- 3
src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs View File

@@ -92,9 +92,17 @@ namespace Tensorflow.Keras.Engine.DataAdapters
var train_y = y[new Slice(0, train_count)]; var train_y = y[new Slice(0, train_count)];
var val_x = x.Select(x => x[new Slice(train_count)] as NDArray); var val_x = x.Select(x => x[new Slice(train_count)] as NDArray);
var val_y = y[new Slice(train_count)]; var val_y = y[new Slice(train_count)];
NDArray tmp_sample_weight = sample_weight;
sample_weight = sample_weight[new Slice(0, train_count)];
ValidationDataPack validation_data = (val_x, val_y, tmp_sample_weight[new Slice(train_count)]);

ValidationDataPack validation_data;
if (sample_weight != null)
{
validation_data = (val_x, val_y, sample_weight[new Slice(train_count)]);
sample_weight = sample_weight[new Slice(0, train_count)];
}
else
{
validation_data = (val_x, val_y);
}
return ((train_x, train_y, sample_weight), validation_data); return ((train_x, train_y, sample_weight), validation_data);
} }
} }


+ 7
- 1
src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs View File

@@ -70,13 +70,19 @@ namespace Tensorflow.Keras.Engine
return evaluate(data_handler, callbacks, is_val, test_function); return evaluate(data_handler, callbacks, is_val, test_function);
} }


public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false)
public Dictionary<string, float> evaluate(
IEnumerable<Tensor> x,
Tensor y,
int verbose = 1,
NDArray sample_weight = null,
bool is_val = false)
{ {
var data_handler = new DataHandler(new DataHandlerArgs var data_handler = new DataHandler(new DataHandlerArgs
{ {
X = new Tensors(x.ToArray()), X = new Tensors(x.ToArray()),
Y = y, Y = y,
Model = this, Model = this,
SampleWeight = sample_weight,
StepsPerExecution = _steps_per_execution StepsPerExecution = _steps_per_execution
}); });




+ 19
- 4
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -7,6 +7,7 @@ using Tensorflow.Keras.Engine.DataAdapters;
using System.Diagnostics; using System.Diagnostics;
using Tensorflow.Keras.Callbacks; using Tensorflow.Keras.Callbacks;
using Tensorflow.Util; using Tensorflow.Util;
using OneOf;


namespace Tensorflow.Keras.Engine namespace Tensorflow.Keras.Engine
{ {
@@ -287,10 +288,24 @@ namespace Tensorflow.Keras.Engine


if (validation_data != null) if (validation_data != null)
{ {
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen
// so we need to pass a is_val parameter to stop on_test_batch_end
var (val_x, val_y, val_sample_weight) = validation_data;
var val_logs = evaluate(val_x, val_y, sample_weight:val_sample_weight, is_val:true);
NDArray val_x;
NDArray[] val_x_array;
NDArray val_y;
NDArray val_sample_weight;
Dictionary<string, float> val_logs;
if (!validation_data.val_x_is_array)
{
(val_x, val_y, val_sample_weight) = validation_data;
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen
// so we need to pass a is_val parameter to stop on_test_batch_end
val_logs = evaluate(val_x, val_y, sample_weight: val_sample_weight, is_val: true);

}
else
{
(val_x_array, val_y, val_sample_weight, _) = validation_data;
val_logs = evaluate(val_x_array, val_y, sample_weight: val_sample_weight, is_val: true);
}
foreach (var log in val_logs) foreach (var log in val_logs)
{ {
logs["val_" + log.Key] = log.Value; logs["val_" + log.Key] = log.Value;


Loading…
Cancel
Save