From 53bd70bed3828a81e83bc1a2edbe1b3cbfab197a Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Tue, 7 Nov 2023 22:54:08 +0800 Subject: [PATCH 1/3] fix: fix the validation_pack when multiple input --- src/TensorFlowNET.Core/Util/Data.cs | 26 ++++++++++++++----- .../Engine/DataAdapters/DataAdapter.cs | 14 +++++++--- .../Engine/Model.Evaluate.cs | 8 +++++- src/TensorFlowNET.Keras/Engine/Model.Fit.cs | 23 +++++++++++++--- 4 files changed, 56 insertions(+), 15 deletions(-) diff --git a/src/TensorFlowNET.Core/Util/Data.cs b/src/TensorFlowNET.Core/Util/Data.cs index a14c69b1..4e5a6543 100644 --- a/src/TensorFlowNET.Core/Util/Data.cs +++ b/src/TensorFlowNET.Core/Util/Data.cs @@ -1,4 +1,5 @@ -using Tensorflow.NumPy; +using OneOf; +using Tensorflow.NumPy; namespace Tensorflow.Util { @@ -8,10 +9,10 @@ namespace Tensorflow.Util /// public class ValidationDataPack { - public NDArray val_x; + public OneOf val_x; public NDArray val_y; public NDArray val_sample_weight = null; - + public bool val_x_is_array = false; public ValidationDataPack((NDArray, NDArray) validation_data) { this.val_x = validation_data.Item1; @@ -27,15 +28,17 @@ namespace Tensorflow.Util public ValidationDataPack((IEnumerable, NDArray) validation_data) { - this.val_x = validation_data.Item1.ToArray()[0]; + this.val_x = validation_data.Item1.ToArray(); this.val_y = validation_data.Item2; + val_x_is_array = true; } public ValidationDataPack((IEnumerable, NDArray, NDArray) validation_data) { - this.val_x = validation_data.Item1.ToArray()[0]; + this.val_x = validation_data.Item1.ToArray(); this.val_y = validation_data.Item2; this.val_sample_weight = validation_data.Item3; + val_x_is_array = true; } public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data) @@ -52,15 +55,24 @@ namespace Tensorflow.Util public void Deconstruct(out NDArray val_x, out NDArray val_y) { - val_x = this.val_x; + val_x = this.val_x.AsT0; val_y = this.val_y; } public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight) { - val_x = this.val_x; + val_x = this.val_x.AsT0; + val_y = this.val_y; + val_sample_weight = this.val_sample_weight; + } + + // add a unuse parameter to make it different from Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight) + public void Deconstruct(out NDArray[] val_x_array, out NDArray val_y, out NDArray val_sample_weight, out NDArray unuse) + { + val_x_array = this.val_x.AsT1; val_y = this.val_y; val_sample_weight = this.val_sample_weight; + unuse = null; } } } diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs index b2750496..590f30a7 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs @@ -92,9 +92,17 @@ namespace Tensorflow.Keras.Engine.DataAdapters var train_y = y[new Slice(0, train_count)]; var val_x = x.Select(x => x[new Slice(train_count)] as NDArray); var val_y = y[new Slice(train_count)]; - NDArray tmp_sample_weight = sample_weight; - sample_weight = sample_weight[new Slice(0, train_count)]; - ValidationDataPack validation_data = (val_x, val_y, tmp_sample_weight[new Slice(train_count)]); + + ValidationDataPack validation_data; + if (sample_weight != null) + { + validation_data = (val_x, val_y, sample_weight[new Slice(train_count)]); + sample_weight = sample_weight[new Slice(0, train_count)]; + } + else + { + validation_data = (val_x, val_y); + } return ((train_x, train_y, sample_weight), validation_data); } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs index 474d5e5a..b3264429 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs @@ -70,13 +70,19 @@ namespace Tensorflow.Keras.Engine return evaluate(data_handler, callbacks, is_val, test_function); } - public Dictionary evaluate(IEnumerable x, Tensor y, int verbose = 1, bool is_val = false) + public Dictionary evaluate( + IEnumerable x, + Tensor y, + int verbose = 1, + NDArray sample_weight = null, + bool is_val = false) { var data_handler = new DataHandler(new DataHandlerArgs { X = new Tensors(x.ToArray()), Y = y, Model = this, + SampleWeight = sample_weight, StepsPerExecution = _steps_per_execution }); diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index d61211c7..13a1b63b 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -7,6 +7,7 @@ using Tensorflow.Keras.Engine.DataAdapters; using System.Diagnostics; using Tensorflow.Keras.Callbacks; using Tensorflow.Util; +using OneOf; namespace Tensorflow.Keras.Engine { @@ -287,10 +288,24 @@ namespace Tensorflow.Keras.Engine if (validation_data != null) { - // Because evaluate calls call_test_batch_end, this interferes with our output on the screen - // so we need to pass a is_val parameter to stop on_test_batch_end - var (val_x, val_y, val_sample_weight) = validation_data; - var val_logs = evaluate(val_x, val_y, sample_weight:val_sample_weight, is_val:true); + NDArray val_x; + NDArray[] val_x_array; + NDArray val_y; + NDArray val_sample_weight; + Dictionary val_logs; + if (!validation_data.val_x_is_array) + { + (val_x, val_y, val_sample_weight) = validation_data; + // Because evaluate calls call_test_batch_end, this interferes with our output on the screen + // so we need to pass a is_val parameter to stop on_test_batch_end + val_logs = evaluate(val_x, val_y, sample_weight: val_sample_weight, is_val: true); + + } + else + { + (val_x_array, val_y, val_sample_weight, _) = validation_data; + val_logs = evaluate(val_x_array, val_y, sample_weight: val_sample_weight, is_val: true); + } foreach (var log in val_logs) { logs["val_" + log.Key] = log.Value; From d453fb6611f4acb3ab405579ae804279d6e07cbe Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Tue, 7 Nov 2023 23:34:37 +0800 Subject: [PATCH 2/3] refactor: declare some field of ValidationPack as internal --- src/TensorFlowNET.Core/Util/Data.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/TensorFlowNET.Core/Util/Data.cs b/src/TensorFlowNET.Core/Util/Data.cs index 4e5a6543..388efc50 100644 --- a/src/TensorFlowNET.Core/Util/Data.cs +++ b/src/TensorFlowNET.Core/Util/Data.cs @@ -9,9 +9,9 @@ namespace Tensorflow.Util /// public class ValidationDataPack { - public OneOf val_x; - public NDArray val_y; - public NDArray val_sample_weight = null; + internal OneOf val_x; + internal NDArray val_y; + internal NDArray val_sample_weight = null; public bool val_x_is_array = false; public ValidationDataPack((NDArray, NDArray) validation_data) { @@ -33,7 +33,7 @@ namespace Tensorflow.Util val_x_is_array = true; } - public ValidationDataPack((IEnumerable, NDArray, NDArray) validation_data) + internal ValidationDataPack((IEnumerable, NDArray, NDArray) validation_data) { this.val_x = validation_data.Item1.ToArray(); this.val_y = validation_data.Item2; From 47e9019a187744bf31e315525ffe352dad36a00c Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Tue, 7 Nov 2023 23:36:15 +0800 Subject: [PATCH 3/3] refactor: fix a typo --- src/TensorFlowNET.Core/Util/Data.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TensorFlowNET.Core/Util/Data.cs b/src/TensorFlowNET.Core/Util/Data.cs index 388efc50..fe3466ed 100644 --- a/src/TensorFlowNET.Core/Util/Data.cs +++ b/src/TensorFlowNET.Core/Util/Data.cs @@ -33,7 +33,7 @@ namespace Tensorflow.Util val_x_is_array = true; } - internal ValidationDataPack((IEnumerable, NDArray, NDArray) validation_data) + public ValidationDataPack((IEnumerable, NDArray, NDArray) validation_data) { this.val_x = validation_data.Item1.ToArray(); this.val_y = validation_data.Item2;