Browse Source

Add a function(get_classification_statistics) to count the number of label categories for the image_dataset_from_directory method.

tags/v0.110.4-Transformer-Model
dogvane 2 years ago
parent
commit
0cc25fbc35
2 changed files with 33 additions and 0 deletions
  1. +32
    -0
      src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
  2. +1
    -0
      src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs

+ 32
- 0
src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs View File

@@ -8,6 +8,37 @@ namespace Tensorflow.Keras
{ {
public static string[] WHITELIST_FORMATS = new[] { ".bmp", ".gif", ".jpeg", ".jpg", ".png" }; public static string[] WHITELIST_FORMATS = new[] { ".bmp", ".gif", ".jpeg", ".jpg", ".png" };


/// <summary>
/// Function that calculates the classification statistics for a given array of classified data.
/// The function takes an array of classified data as input and returns a dictionary containing the count and percentage of each class in the input array.
/// This function can be used to analyze the distribution of classes in a dataset or to evaluate the performance of a classification model.
/// </summary>
/// <remarks>
/// code from copilot
/// </remarks>
/// <param name="label_ids"></param>
/// <param name="label_class_names"></param>
Dictionary<string, double> get_classification_statistics(int[] label_ids, string[] label_class_names)
{
var countDict = label_ids.GroupBy(x => x)
.ToDictionary(g => g.Key, g => g.Count());
var totalCount = label_ids.Length;
var ratioDict = label_class_names.ToDictionary(name => name,
name =>
(double)(countDict.ContainsKey(Array.IndexOf(label_class_names, name))
? countDict[Array.IndexOf(label_class_names, name)] : 0)
/ totalCount);

print("Classification statistics:");
foreach (string labelName in label_class_names)
{
double ratio = ratioDict[labelName];
print($"{labelName}: {ratio * 100:F2}%");
}

return ratioDict;
}

/// <summary> /// <summary>
/// Generates a `tf.data.Dataset` from image files in a directory. /// Generates a `tf.data.Dataset` from image files in a directory.
/// https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory /// https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory
@@ -53,6 +84,7 @@ namespace Tensorflow.Keras
follow_links: follow_links); follow_links: follow_links);


(image_paths, label_list) = keras.preprocessing.dataset_utils.get_training_or_validation_split(image_paths, label_list, validation_split, subset); (image_paths, label_list) = keras.preprocessing.dataset_utils.get_training_or_validation_split(image_paths, label_list, validation_split, subset);
get_classification_statistics(label_list, class_name_list);


var dataset = paths_and_labels_to_dataset(image_paths, image_size, num_channels, label_list, label_mode, class_name_list.Length, interpolation); var dataset = paths_and_labels_to_dataset(image_paths, image_size, num_channels, label_list, label_mode, class_name_list.Length, interpolation);
if (shuffle) if (shuffle)


+ 1
- 0
src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs View File

@@ -9,6 +9,7 @@ namespace Tensorflow.Keras


/// <summary> /// <summary>
/// 图片路径转为数据处理用的dataset /// 图片路径转为数据处理用的dataset
/// 通常用于预测时读取图片
/// </summary> /// </summary>
/// <param name="image_paths"></param> /// <param name="image_paths"></param>
/// <param name="image_size"></param> /// <param name="image_size"></param>


Loading…
Cancel
Save