fix: fix the validation_pack when multiple inputpull/1234/head
@@ -1,4 +1,5 @@ | |||
using Tensorflow.NumPy; | |||
using OneOf; | |||
using Tensorflow.NumPy; | |||
namespace Tensorflow.Util | |||
{ | |||
@@ -8,10 +9,10 @@ namespace Tensorflow.Util | |||
/// </summary> | |||
public class ValidationDataPack | |||
{ | |||
public NDArray val_x; | |||
public NDArray val_y; | |||
public NDArray val_sample_weight = null; | |||
internal OneOf<NDArray, NDArray[]> val_x; | |||
internal NDArray val_y; | |||
internal NDArray val_sample_weight = null; | |||
public bool val_x_is_array = false; | |||
public ValidationDataPack((NDArray, NDArray) validation_data) | |||
{ | |||
this.val_x = validation_data.Item1; | |||
@@ -27,15 +28,17 @@ namespace Tensorflow.Util | |||
public ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data) | |||
{ | |||
this.val_x = validation_data.Item1.ToArray()[0]; | |||
this.val_x = validation_data.Item1.ToArray(); | |||
this.val_y = validation_data.Item2; | |||
val_x_is_array = true; | |||
} | |||
public ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data) | |||
{ | |||
this.val_x = validation_data.Item1.ToArray()[0]; | |||
this.val_x = validation_data.Item1.ToArray(); | |||
this.val_y = validation_data.Item2; | |||
this.val_sample_weight = validation_data.Item3; | |||
val_x_is_array = true; | |||
} | |||
public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data) | |||
@@ -52,15 +55,24 @@ namespace Tensorflow.Util | |||
public void Deconstruct(out NDArray val_x, out NDArray val_y) | |||
{ | |||
val_x = this.val_x; | |||
val_x = this.val_x.AsT0; | |||
val_y = this.val_y; | |||
} | |||
public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight) | |||
{ | |||
val_x = this.val_x; | |||
val_x = this.val_x.AsT0; | |||
val_y = this.val_y; | |||
val_sample_weight = this.val_sample_weight; | |||
} | |||
// add a unuse parameter to make it different from Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight) | |||
public void Deconstruct(out NDArray[] val_x_array, out NDArray val_y, out NDArray val_sample_weight, out NDArray unuse) | |||
{ | |||
val_x_array = this.val_x.AsT1; | |||
val_y = this.val_y; | |||
val_sample_weight = this.val_sample_weight; | |||
unuse = null; | |||
} | |||
} | |||
} |
@@ -92,9 +92,17 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
var train_y = y[new Slice(0, train_count)]; | |||
var val_x = x.Select(x => x[new Slice(train_count)] as NDArray); | |||
var val_y = y[new Slice(train_count)]; | |||
NDArray tmp_sample_weight = sample_weight; | |||
sample_weight = sample_weight[new Slice(0, train_count)]; | |||
ValidationDataPack validation_data = (val_x, val_y, tmp_sample_weight[new Slice(train_count)]); | |||
ValidationDataPack validation_data; | |||
if (sample_weight != null) | |||
{ | |||
validation_data = (val_x, val_y, sample_weight[new Slice(train_count)]); | |||
sample_weight = sample_weight[new Slice(0, train_count)]; | |||
} | |||
else | |||
{ | |||
validation_data = (val_x, val_y); | |||
} | |||
return ((train_x, train_y, sample_weight), validation_data); | |||
} | |||
} | |||
@@ -70,13 +70,19 @@ namespace Tensorflow.Keras.Engine | |||
return evaluate(data_handler, callbacks, is_val, test_function); | |||
} | |||
public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false) | |||
public Dictionary<string, float> evaluate( | |||
IEnumerable<Tensor> x, | |||
Tensor y, | |||
int verbose = 1, | |||
NDArray sample_weight = null, | |||
bool is_val = false) | |||
{ | |||
var data_handler = new DataHandler(new DataHandlerArgs | |||
{ | |||
X = new Tensors(x.ToArray()), | |||
Y = y, | |||
Model = this, | |||
SampleWeight = sample_weight, | |||
StepsPerExecution = _steps_per_execution | |||
}); | |||
@@ -7,6 +7,7 @@ using Tensorflow.Keras.Engine.DataAdapters; | |||
using System.Diagnostics; | |||
using Tensorflow.Keras.Callbacks; | |||
using Tensorflow.Util; | |||
using OneOf; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
@@ -287,10 +288,24 @@ namespace Tensorflow.Keras.Engine | |||
if (validation_data != null) | |||
{ | |||
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen | |||
// so we need to pass a is_val parameter to stop on_test_batch_end | |||
var (val_x, val_y, val_sample_weight) = validation_data; | |||
var val_logs = evaluate(val_x, val_y, sample_weight:val_sample_weight, is_val:true); | |||
NDArray val_x; | |||
NDArray[] val_x_array; | |||
NDArray val_y; | |||
NDArray val_sample_weight; | |||
Dictionary<string, float> val_logs; | |||
if (!validation_data.val_x_is_array) | |||
{ | |||
(val_x, val_y, val_sample_weight) = validation_data; | |||
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen | |||
// so we need to pass a is_val parameter to stop on_test_batch_end | |||
val_logs = evaluate(val_x, val_y, sample_weight: val_sample_weight, is_val: true); | |||
} | |||
else | |||
{ | |||
(val_x_array, val_y, val_sample_weight, _) = validation_data; | |||
val_logs = evaluate(val_x_array, val_y, sample_weight: val_sample_weight, is_val: true); | |||
} | |||
foreach (var log in val_logs) | |||
{ | |||
logs["val_" + log.Key] = log.Value; | |||