@@ -1,4 +1,5 @@ | |||||
using Tensorflow.NumPy; | |||||
using OneOf; | |||||
using Tensorflow.NumPy; | |||||
namespace Tensorflow.Util | namespace Tensorflow.Util | ||||
{ | { | ||||
@@ -8,10 +9,10 @@ namespace Tensorflow.Util | |||||
/// </summary> | /// </summary> | ||||
public class ValidationDataPack | public class ValidationDataPack | ||||
{ | { | ||||
public NDArray val_x; | |||||
public NDArray val_y; | |||||
public NDArray val_sample_weight = null; | |||||
internal OneOf<NDArray, NDArray[]> val_x; | |||||
internal NDArray val_y; | |||||
internal NDArray val_sample_weight = null; | |||||
public bool val_x_is_array = false; | |||||
public ValidationDataPack((NDArray, NDArray) validation_data) | public ValidationDataPack((NDArray, NDArray) validation_data) | ||||
{ | { | ||||
this.val_x = validation_data.Item1; | this.val_x = validation_data.Item1; | ||||
@@ -27,15 +28,17 @@ namespace Tensorflow.Util | |||||
public ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data) | public ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data) | ||||
{ | { | ||||
this.val_x = validation_data.Item1.ToArray()[0]; | |||||
this.val_x = validation_data.Item1.ToArray(); | |||||
this.val_y = validation_data.Item2; | this.val_y = validation_data.Item2; | ||||
val_x_is_array = true; | |||||
} | } | ||||
public ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data) | public ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data) | ||||
{ | { | ||||
this.val_x = validation_data.Item1.ToArray()[0]; | |||||
this.val_x = validation_data.Item1.ToArray(); | |||||
this.val_y = validation_data.Item2; | this.val_y = validation_data.Item2; | ||||
this.val_sample_weight = validation_data.Item3; | this.val_sample_weight = validation_data.Item3; | ||||
val_x_is_array = true; | |||||
} | } | ||||
public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data) | public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data) | ||||
@@ -52,15 +55,24 @@ namespace Tensorflow.Util | |||||
public void Deconstruct(out NDArray val_x, out NDArray val_y) | public void Deconstruct(out NDArray val_x, out NDArray val_y) | ||||
{ | { | ||||
val_x = this.val_x; | |||||
val_x = this.val_x.AsT0; | |||||
val_y = this.val_y; | val_y = this.val_y; | ||||
} | } | ||||
public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight) | public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight) | ||||
{ | { | ||||
val_x = this.val_x; | |||||
val_x = this.val_x.AsT0; | |||||
val_y = this.val_y; | |||||
val_sample_weight = this.val_sample_weight; | |||||
} | |||||
// add a unuse parameter to make it different from Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight) | |||||
public void Deconstruct(out NDArray[] val_x_array, out NDArray val_y, out NDArray val_sample_weight, out NDArray unuse) | |||||
{ | |||||
val_x_array = this.val_x.AsT1; | |||||
val_y = this.val_y; | val_y = this.val_y; | ||||
val_sample_weight = this.val_sample_weight; | val_sample_weight = this.val_sample_weight; | ||||
unuse = null; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -92,9 +92,17 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
var train_y = y[new Slice(0, train_count)]; | var train_y = y[new Slice(0, train_count)]; | ||||
var val_x = x.Select(x => x[new Slice(train_count)] as NDArray); | var val_x = x.Select(x => x[new Slice(train_count)] as NDArray); | ||||
var val_y = y[new Slice(train_count)]; | var val_y = y[new Slice(train_count)]; | ||||
NDArray tmp_sample_weight = sample_weight; | |||||
sample_weight = sample_weight[new Slice(0, train_count)]; | |||||
ValidationDataPack validation_data = (val_x, val_y, tmp_sample_weight[new Slice(train_count)]); | |||||
ValidationDataPack validation_data; | |||||
if (sample_weight != null) | |||||
{ | |||||
validation_data = (val_x, val_y, sample_weight[new Slice(train_count)]); | |||||
sample_weight = sample_weight[new Slice(0, train_count)]; | |||||
} | |||||
else | |||||
{ | |||||
validation_data = (val_x, val_y); | |||||
} | |||||
return ((train_x, train_y, sample_weight), validation_data); | return ((train_x, train_y, sample_weight), validation_data); | ||||
} | } | ||||
} | } | ||||
@@ -70,13 +70,19 @@ namespace Tensorflow.Keras.Engine | |||||
return evaluate(data_handler, callbacks, is_val, test_function); | return evaluate(data_handler, callbacks, is_val, test_function); | ||||
} | } | ||||
public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false) | |||||
public Dictionary<string, float> evaluate( | |||||
IEnumerable<Tensor> x, | |||||
Tensor y, | |||||
int verbose = 1, | |||||
NDArray sample_weight = null, | |||||
bool is_val = false) | |||||
{ | { | ||||
var data_handler = new DataHandler(new DataHandlerArgs | var data_handler = new DataHandler(new DataHandlerArgs | ||||
{ | { | ||||
X = new Tensors(x.ToArray()), | X = new Tensors(x.ToArray()), | ||||
Y = y, | Y = y, | ||||
Model = this, | Model = this, | ||||
SampleWeight = sample_weight, | |||||
StepsPerExecution = _steps_per_execution | StepsPerExecution = _steps_per_execution | ||||
}); | }); | ||||
@@ -7,6 +7,7 @@ using Tensorflow.Keras.Engine.DataAdapters; | |||||
using System.Diagnostics; | using System.Diagnostics; | ||||
using Tensorflow.Keras.Callbacks; | using Tensorflow.Keras.Callbacks; | ||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
using OneOf; | |||||
namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
{ | { | ||||
@@ -287,10 +288,24 @@ namespace Tensorflow.Keras.Engine | |||||
if (validation_data != null) | if (validation_data != null) | ||||
{ | { | ||||
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen | |||||
// so we need to pass a is_val parameter to stop on_test_batch_end | |||||
var (val_x, val_y, val_sample_weight) = validation_data; | |||||
var val_logs = evaluate(val_x, val_y, sample_weight:val_sample_weight, is_val:true); | |||||
NDArray val_x; | |||||
NDArray[] val_x_array; | |||||
NDArray val_y; | |||||
NDArray val_sample_weight; | |||||
Dictionary<string, float> val_logs; | |||||
if (!validation_data.val_x_is_array) | |||||
{ | |||||
(val_x, val_y, val_sample_weight) = validation_data; | |||||
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen | |||||
// so we need to pass a is_val parameter to stop on_test_batch_end | |||||
val_logs = evaluate(val_x, val_y, sample_weight: val_sample_weight, is_val: true); | |||||
} | |||||
else | |||||
{ | |||||
(val_x_array, val_y, val_sample_weight, _) = validation_data; | |||||
val_logs = evaluate(val_x_array, val_y, sample_weight: val_sample_weight, is_val: true); | |||||
} | |||||
foreach (var log in val_logs) | foreach (var log in val_logs) | ||||
{ | { | ||||
logs["val_" + log.Key] = log.Value; | logs["val_" + log.Key] = log.Value; | ||||