using NumSharp;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using Tensorflow;
using TensorFlowNET.Examples.Utility;
using static Tensorflow.Python;

namespace TensorFlowNET.Examples.ImageProcess
{
    /// <summary>
    /// In this tutorial, we will reuse the feature extraction capabilities from powerful image classifiers trained on ImageNet
    /// and simply train a new classification layer on top. Transfer learning is a technique that shortcuts much of this
    /// by taking a piece of a model that has already been trained on a related task and reusing it in a new model.
    ///
    /// https://www.tensorflow.org/hub/tutorials/image_retraining
    /// </summary>
    public class RetrainImageClassifier : IExample
    {
        public int Priority => 16;
        public bool Enabled { get; set; } = false;
        public bool ImportGraph { get; set; } = true;
        public string Name => "Retrain Image Classifier";

        const string data_dir = "retrain_images";
        string summaries_dir = Path.Join(data_dir, "retrain_logs");
        string image_dir = Path.Join(data_dir, "flower_photos");
        string bottleneck_dir = Path.Join(data_dir, "bottleneck");
        string tfhub_module = "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3";
        float testing_percentage = 0.1f;
        float validation_percentage = 0.1f;
        // Input placeholder of the imported InceptionV3 graph; resolved in Run().
        Tensor resized_image_tensor;
        // label_name -> { "training" | "testing" | "validation" -> image file paths }
        Dictionary<string, Dictionary<string, string[]>> image_lists;

        /// <summary>
        /// Downloads the data, imports the pre-trained InceptionV3 graph and caches
        /// the bottleneck values for every image on disk.
        /// NOTE(review): the actual training loop is not implemented here yet, so this
        /// always returns false.
        /// </summary>
        /// <returns>false (no training/evaluation result is produced yet).</returns>
        public bool Run()
        {
            PrepareData();

            var graph = tf.Graph().as_default();
            tf.train.import_meta_graph("graph/InceptionV3.meta");

            Tensor bottleneck_tensor = graph.OperationByName("module_apply_default/hub_output/feature_vector/SpatialSqueeze");
            // Assign the field instead of shadowing it with a local of the same name,
            // so the rest of the class sees the resolved tensor too.
            resized_image_tensor = graph.OperationByName("Placeholder");
            Tensor final_tensor = graph.OperationByName("final_result");
            Tensor ground_truth_input = graph.OperationByName("input/GroundTruthInput");

            with(tf.Session(graph), sess =>
            {
                // Initialize all weights: for the module to their pretrained values,
                // and for the newly added retraining layer to random initial values.
                var init = tf.global_variables_initializer();
                sess.run(init);

                var (jpeg_data_tensor, decoded_image_tensor) = add_jpeg_decoding();

                // We'll make sure we've calculated the 'bottleneck' image summaries and
                // cached them on disk.
                cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
                    jpeg_data_tensor, decoded_image_tensor,
                    resized_image_tensor, bottleneck_tensor, tfhub_module);

                // Create the operations we need to evaluate the accuracy of our new layer.
                var (evaluation_step, _) = add_evaluation_step(final_tensor, ground_truth_input);

                // Merge all the summaries and write them out to the summaries_dir
                var merged = tf.summary.merge_all();
                var train_writer = tf.summary.FileWriter(summaries_dir + "/train", sess.graph);
            });

            return false;
        }

        /// <summary>
        /// Inserts the operations we need to evaluate the accuracy of our results.
        /// </summary>
        /// <param name="result_tensor">The new final node that produces the class probabilities.</param>
        /// <param name="ground_truth_tensor">The node we feed the ground truth labels into.</param>
        /// <returns>Tuple of (evaluation step, prediction) tensors.</returns>
        private (Tensor, Tensor) add_evaluation_step(Tensor result_tensor, Tensor ground_truth_tensor)
        {
            Tensor evaluation_step = null, correct_prediction = null, prediction = null;

            with(tf.name_scope("accuracy"), scope =>
            {
                with(tf.name_scope("correct_prediction"), delegate
                {
                    // argmax over the class axis gives the predicted label index.
                    prediction = tf.argmax(result_tensor, 1);
                    correct_prediction = tf.equal(prediction, ground_truth_tensor);
                });

                with(tf.name_scope("accuracy"), delegate
                {
                    // Mean of the 0/1 correctness vector is the accuracy.
                    evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32));
                });
            });

            tf.summary.scalar("accuracy", evaluation_step);
            return (evaluation_step, prediction);
        }

        /// <summary>
        /// Ensures all the training, testing, and validation bottlenecks are cached.
/// </summary>
        /// <param name="sess">The current active TensorFlow Session.</param>
        /// <param name="image_lists">Per-label dictionary of training/testing/validation image paths.</param>
        /// <param name="image_dir">Root folder of the sub-folders containing the training images.</param>
        /// <param name="bottleneck_dir">Folder holding cached files of bottleneck values.</param>
        /// <param name="jpeg_data_tensor">Input tensor for jpeg data from file.</param>
        /// <param name="decoded_image_tensor">The output of decoding and resizing the image.</param>
        /// <param name="resized_input_tensor">The input node of the recognition graph.</param>
        /// <param name="bottleneck_tensor">The penultimate output layer of the graph.</param>
        /// <param name="module_name">The name of the image module being used, used to key cache files.</param>
        private void cache_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
            string image_dir, string bottleneck_dir, Tensor jpeg_data_tensor, Tensor decoded_image_tensor,
            Tensor resized_input_tensor, Tensor bottleneck_tensor, string module_name)
        {
            int how_many_bottlenecks = 0;
            foreach (var (label_name, label_lists) in image_lists)
            {
                foreach (var category in new string[] { "training", "testing", "validation" })
                {
                    var category_list = label_lists[category];
                    foreach (var (index, unused_base_name) in enumerate(category_list))
                    {
                        get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir, category,
                            bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
                            resized_input_tensor, bottleneck_tensor, module_name);
                        how_many_bottlenecks++;
                        // Progress report every 100 files; creating bottlenecks is slow.
                        if (how_many_bottlenecks % 100 == 0)
                            print($"{how_many_bottlenecks} bottleneck files created.");
                    }
                }
            }
        }

        /// <summary>
        /// Retrieves or calculates bottleneck values for an image, caching them on disk.
        /// If a cached file exists it is reused; otherwise the image is run through the
        /// network and the resulting vector is written as comma-separated floats.
        /// </summary>
        /// <returns>The bottleneck values for the requested image.</returns>
        private float[] get_or_create_bottleneck(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
            string label_name, int index, string image_dir, string category, string bottleneck_dir,
            Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor,
            Tensor bottleneck_tensor, string module_name)
        {
            var label_lists = image_lists[label_name];
            var sub_dir_path = Path.Join(bottleneck_dir, label_name);
            Directory.CreateDirectory(sub_dir_path);

            string bottleneck_path = get_bottleneck_path(image_lists, label_name, index,
                bottleneck_dir, category, module_name);
            if (!File.Exists(bottleneck_path))
                create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
                    image_dir, category, sess, jpeg_data_tensor,
                    decoded_image_tensor, resized_input_tensor, bottleneck_tensor);

            var bottleneck_string = File.ReadAllText(bottleneck_path);
            // Parse with the invariant culture so cached files round-trip regardless of the
            // machine's locale (some locales use ',' as the decimal separator, which would
            // collide with the ',' field separator used by create_bottleneck_file).
            var bottleneck_values = Array.ConvertAll(bottleneck_string.Split(','),
                x => float.Parse(x, System.Globalization.CultureInfo.InvariantCulture));
            return bottleneck_values;
        }

        /// <summary>
        /// Creates a single bottleneck cache file by running one image through the network.
        /// </summary>
        private void create_bottleneck_file(string bottleneck_path, Dictionary<string, Dictionary<string, string[]>> image_lists,
            string label_name, int index, string image_dir, string category, Session sess,
            Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor,
            Tensor bottleneck_tensor)
        {
            // Create a single bottleneck file.
            print("Creating bottleneck at " + bottleneck_path);
            var image_path = get_image_path(image_lists, label_name, index, image_dir, category);
            // NOTE(review): this only logs a missing file; File.ReadAllBytes below will
            // then throw. Kept as-is to preserve the original failure behavior.
            if (!File.Exists(image_path))
                print($"File does not exist {image_path}");
            var image_data = File.ReadAllBytes(image_path);
            var bottleneck_values = run_bottleneck_on_image(sess, image_data, jpeg_data_tensor,
                decoded_image_tensor, resized_input_tensor, bottleneck_tensor);
            var values = bottleneck_values.Data<float>();
            // Serialize with the invariant culture so float.Parse in get_or_create_bottleneck
            // can read the file back on any locale.
            var bottleneck_string = string.Join(",",
                values.Select(v => v.ToString(System.Globalization.CultureInfo.InvariantCulture)));
            File.WriteAllText(bottleneck_path, bottleneck_string);
        }

        /// <summary>
        /// Runs inference on an image to extract the 'bottleneck' summary layer.
        /// </summary>
        /// <param name="sess">Current active TensorFlow Session.</param>
        /// <param name="image_data">Data of raw JPEG data.</param>
        /// <param name="image_data_tensor">Input data layer in the graph.</param>
        /// <param name="decoded_image_tensor">Output of initial image resizing and preprocessing.</param>
        /// <param name="resized_input_tensor">The input node of the recognition graph.</param>
        /// <param name="bottleneck_tensor">Layer before the final softmax.</param>
        /// <returns>Squeezed (singleton dimensions removed) bottleneck values as an NDArray.</returns>
        private NDArray run_bottleneck_on_image(Session sess, byte[] image_data, Tensor image_data_tensor,
            Tensor decoded_image_tensor, Tensor resized_input_tensor, Tensor bottleneck_tensor)
        {
            // First decode the JPEG image, resize it, and rescale the pixel values.
            var resized_input_values = sess.run(decoded_image_tensor,
                new FeedItem(image_data_tensor, image_data));
            // Then run it through the recognition network.
            var bottleneck_values = sess.run(bottleneck_tensor,
                new FeedItem(resized_input_tensor, resized_input_values));
            bottleneck_values = np.squeeze(bottleneck_values);
            return bottleneck_values;
        }

        /// <summary>
        /// Returns a path to a bottleneck file for a label at the given index.
        /// The module name is sanitized so it can appear in a file name.
        /// </summary>
        private string get_bottleneck_path(Dictionary<string, Dictionary<string, string[]>> image_lists,
            string label_name, int index, string bottleneck_dir, string category, string module_name)
        {
            module_name = (module_name.Replace("://", "~") // URL scheme.
                .Replace('/', '~') // URL and Unix paths.
                .Replace(':', '~').Replace('\\', '~')); // Windows paths.
            return get_image_path(image_lists, label_name, index, bottleneck_dir, category)
                + "_" + module_name + ".txt";
        }

        /// <summary>
        /// Returns a path to an image for a label at the given index.
        /// The index wraps modulo the category size, so any non-negative index is valid.
        /// </summary>
        private string get_image_path(Dictionary<string, Dictionary<string, string[]>> image_lists,
            string label_name, int index, string image_dir, string category)
        {
            if (!image_lists.ContainsKey(label_name))
                print($"Label does not exist {label_name}");
            var label_lists = image_lists[label_name];
            if (!label_lists.ContainsKey(category))
                print($"Category does not exist {category}");
            var category_list = label_lists[category];
            // NOTE(review): an empty category is only logged here; the modulo below would
            // then throw DivideByZeroException. Kept to preserve the original behavior.
            if (category_list.Length == 0)
                print($"Label {label_name} has no images in the category {category}.");
            var mod_index = index % len(category_list);
            var base_name = category_list[mod_index].Split(Path.DirectorySeparatorChar).Last();
            var sub_dir = label_name;
            var full_path = Path.Join(image_dir, sub_dir, base_name);
            return full_path;
        }

        /// <summary>
        /// Downloads the flower photos and the InceptionV3 graph meta file, creates the
        /// working directories, and builds the per-label image lists.
        /// </summary>
        public void PrepareData()
        {
            // get a set of images to teach the network about the new classes
            string fileName = "flower_photos.tgz";
            string url = $"http://download.tensorflow.org/models/{fileName}";
            Web.Download(url, data_dir, fileName);
            Compress.ExtractTGZ(Path.Join(data_dir, fileName), data_dir);

            // download graph meta data
            url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/InceptionV3.meta";
            Web.Download(url, "graph", "InceptionV3.meta");

            // Prepare necessary directories that can be used during training
            Directory.CreateDirectory(summaries_dir);
            Directory.CreateDirectory(bottleneck_dir);

            // Look at the folder structure, and create lists of all the images.
            image_lists = create_image_lists();
            var class_count = len(image_lists);
            if (class_count == 0)
                print($"No valid folders of images found at {image_dir}");
            if (class_count == 1)
                print("Only one valid folder of images found at " + image_dir +
                    " - multiple classes are needed for classification.");
        }

        /// <summary>
        /// Adds operations that perform JPEG decoding and resizing to the graph.
        /// </summary>
        /// <returns>Tuple of (jpeg input placeholder, resized float image tensor).</returns>
        private (Tensor, Tensor) add_jpeg_decoding()
        {
            // height, width, depth expected by the InceptionV3 module.
            var input_dim = (299, 299, 3);
            var jpeg_data = tf.placeholder(tf.chars, name: "DecodeJPGInput");
            var decoded_image = tf.image.decode_jpeg(jpeg_data, channels: input_dim.Item3);
            // Convert from full range of uint8 to range [0,1] of float32.
            var decoded_image_as_float = tf.image.convert_image_dtype(decoded_image, tf.float32);
            // Add a batch dimension so resize_bilinear gets a 4-D tensor.
            var decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0);
            var resize_shape = tf.stack(new int[] { input_dim.Item1, input_dim.Item2 });
            var resize_shape_as_int = tf.cast(resize_shape, dtype: tf.int32);
            var resized_image = tf.image.resize_bilinear(decoded_image_4d, resize_shape_as_int);
            return (jpeg_data, resized_image);
        }

        /// <summary>
        /// Builds a list of training images from the file system.
/// Splits the images of each sub-folder of image_dir into training, testing,
        /// and validation sets according to testing_percentage / validation_percentage.
        /// </summary>
        /// <returns>
        /// Dictionary mapping lower-cased label name to a dictionary with "training",
        /// "testing" and "validation" arrays of image file paths.
        /// </returns>
        private Dictionary<string, Dictionary<string, string[]>> create_image_lists()
        {
            var sub_dirs = tf.gfile.Walk(image_dir)
                .Select(x => x.Item1)
                .OrderBy(x => x)
                .ToArray();

            var result = new Dictionary<string, Dictionary<string, string[]>>();

            foreach (var sub_dir in sub_dirs)
            {
                // Walk yields the root directory itself; skip it so the root is not
                // mistaken for a label class (mirrors the upstream retrain script).
                if (sub_dir == image_dir)
                    continue;

                var dir_name = sub_dir.Split(Path.DirectorySeparatorChar).Last();
                print($"Looking for images in '{dir_name}'");
                // NOTE(review): no extension filter here — every file in the folder is
                // treated as an image; verify folders contain only JPEGs.
                var file_list = Directory.GetFiles(sub_dir);
                if (len(file_list) < 20)
                    print($"WARNING: Folder has less than 20 images, which may cause issues.");

                var label_name = dir_name.ToLower();
                result[label_name] = new Dictionary<string, string[]>();

                // Deterministic split by file order: first testing_percentage of the files
                // for testing, the next validation_percentage for validation, rest for training.
                int testing_count = (int)Math.Floor(file_list.Length * testing_percentage);
                int validation_count = (int)Math.Floor(file_list.Length * validation_percentage);
                result[label_name]["testing"] = file_list.Take(testing_count).ToArray();
                result[label_name]["validation"] = file_list.Skip(testing_count).Take(validation_count).ToArray();
                result[label_name]["training"] = file_list.Skip(testing_count + validation_count).ToArray();
            }

            return result;
        }
    }
}