From 0cc25fbc35eb406c4f7e93ae9894633c03bfadae Mon Sep 17 00:00:00 2001 From: dogvane Date: Wed, 12 Jul 2023 17:00:16 +0800 Subject: [PATCH] =?UTF-8?q?Add=20a=20function=EF=BC=88get=5Fclassification?= =?UTF-8?q?=5Fstatistics=EF=BC=89=20to=20count=20the=20number=20of=20label?= =?UTF-8?q?=20categories=20for=20the=20image=5Fdataset=5Ffrom=5Fdirectory?= =?UTF-8?q?=20method.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...processing.image_dataset_from_directory.cs | 32 +++++++++++++++++++ ...eprocessing.paths_and_labels_to_dataset.cs | 1 + 2 files changed, 33 insertions(+) diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs index f42d12cd..377ac4de 100644 --- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs +++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs @@ -8,6 +8,37 @@ namespace Tensorflow.Keras { public static string[] WHITELIST_FORMATS = new[] { ".bmp", ".gif", ".jpeg", ".jpg", ".png" }; + /// + /// Function that calculates the classification statistics for a given array of classified data. + /// The function takes an array of classified data as input and returns a dictionary containing the count and percentage of each class in the input array. + /// This function can be used to analyze the distribution of classes in a dataset or to evaluate the performance of a classification model. + /// + /// + /// code from copilot + /// + /// + /// + Dictionary get_classification_statistics(int[] label_ids, string[] label_class_names) + { + var countDict = label_ids.GroupBy(x => x) + .ToDictionary(g => g.Key, g => g.Count()); + var totalCount = label_ids.Length; + var ratioDict = label_class_names.ToDictionary(name => name, + name => + (double)(countDict.ContainsKey(Array.IndexOf(label_class_names, name)) + ? countDict[Array.IndexOf(label_class_names, name)] : 0) + / totalCount); + + print("Classification statistics:"); + foreach (string labelName in label_class_names) + { + double ratio = ratioDict[labelName]; + print($"{labelName}: {ratio * 100:F2}%"); + } + + return ratioDict; + } + /// /// Generates a `tf.data.Dataset` from image files in a directory. /// https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory @@ -53,6 +84,7 @@ namespace Tensorflow.Keras follow_links: follow_links); (image_paths, label_list) = keras.preprocessing.dataset_utils.get_training_or_validation_split(image_paths, label_list, validation_split, subset); + get_classification_statistics(label_list, class_name_list); var dataset = paths_and_labels_to_dataset(image_paths, image_size, num_channels, label_list, label_mode, class_name_list.Length, interpolation); if (shuffle) diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs index eaa762d8..232f81eb 100644 --- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs +++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs @@ -9,6 +9,7 @@ namespace Tensorflow.Keras /// /// 图片路径转为数据处理用的dataset + /// 通常用于预测时读取图片 /// /// ///