You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

RetrainImageClassifier.cs 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Diagnostics;
  5. using System.IO;
  6. using System.Linq;
  7. using System.Text;
  8. using Tensorflow;
  9. using TensorFlowNET.Examples.Utility;
  10. using static Tensorflow.Python;
  namespace TensorFlowNET.Examples.ImageProcess
  {
      /// <summary>
      /// In this tutorial, we will reuse the feature extraction capabilities from powerful image classifiers trained on ImageNet
      /// and simply train a new classification layer on top. Transfer learning is a technique that shortcuts much of this
      /// by taking a piece of a model that has already been trained on a related task and reusing it in a new model.
      ///
      /// https://www.tensorflow.org/hub/tutorials/image_retraining
      /// </summary>
      /// <remarks>
      /// The retraining graph is imported from <c>graph/InceptionV3.meta</c>. Bottleneck vectors produced by the
      /// feature-extraction part of that graph are cached on disk as comma-separated text files under
      /// <c>retrain_images/bottleneck</c>.
      /// </remarks>
      public class RetrainImageClassifier : IExample
      {
          public int Priority => 16;
          public bool Enabled { get; set; } = false;
          public bool ImportGraph { get; set; } = true;
          public string Name => "Retrain Image Classifier";

          const string data_dir = "retrain_images";
          string summaries_dir = Path.Join(data_dir, "retrain_logs");
          string image_dir = Path.Join(data_dir, "flower_photos");
          string bottleneck_dir = Path.Join(data_dir, "bottleneck");
          // The location where variable checkpoints will be stored.
          string CHECKPOINT_NAME = Path.Join(data_dir, "_retrain_checkpoint");
          string tfhub_module = "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3";
          // Fractions (not percentages) of each class's images reserved for testing / validation.
          float testing_percentage = 0.1f;
          float validation_percentage = 0.1f;
          Tensor resized_image_tensor;
          // label -> ("training" | "testing" | "validation") -> image file paths.
          Dictionary<string, Dictionary<string, string[]>> image_lists;
          int how_many_training_steps = 200;
          int eval_step_interval = 10;
          int train_batch_size = 100;
          int validation_batch_size = 100;
          int intermediate_store_frequency = 0;
          const int MAX_NUM_IMAGES_PER_CLASS = 134217727;
          // File extensions (compared case-insensitively) accepted as training images.
          static readonly string[] IMAGE_EXTENSIONS = { ".jpg", ".jpeg" };
          // A single shared generator. Creating a new Random per draw can repeat seeds on
          // time-seeded runtimes, which makes "random" batches repeat the same sample.
          readonly Random rng = new Random();

          /// <summary>
          /// Downloads the data, imports the retraining graph, caches bottlenecks and runs
          /// <see cref="how_many_training_steps"/> training steps, then saves a checkpoint.
          /// </summary>
          public bool Run()
          {
              PrepareData();

              var graph = tf.Graph().as_default();
              tf.train.import_meta_graph("graph/InceptionV3.meta");

              // Look up the pieces of the pre-built retraining graph by name.
              Tensor bottleneck_tensor = graph.OperationByName("module_apply_default/hub_output/feature_vector/SpatialSqueeze");
              Tensor resized_image_tensor = graph.OperationByName("Placeholder");
              Tensor final_tensor = graph.OperationByName("final_result");
              Tensor ground_truth_input = graph.OperationByName("input/GroundTruthInput");
              Operation train_step = graph.OperationByName("train/GradientDescent");
              Tensor bottleneck_input = graph.OperationByName("input/BottleneckInputPlaceholder");
              Tensor cross_entropy = graph.OperationByName("cross_entropy/sparse_softmax_cross_entropy_loss/value");

              with(tf.Session(graph), sess =>
              {
                  // Initialize all weights: for the module to their pretrained values,
                  // and for the newly added retraining layer to random initial values.
                  var init = tf.global_variables_initializer();
                  sess.run(init);

                  var (jpeg_data_tensor, decoded_image_tensor) = add_jpeg_decoding();

                  // We'll make sure we've calculated the 'bottleneck' image summaries and
                  // cached them on disk.
                  cache_bottlenecks(sess, image_lists, image_dir,
                      bottleneck_dir, jpeg_data_tensor,
                      decoded_image_tensor, resized_image_tensor,
                      bottleneck_tensor, tfhub_module);

                  // Create the operations we need to evaluate the accuracy of our new layer.
                  var (evaluation_step, _) = add_evaluation_step(final_tensor, ground_truth_input);

                  // Merge all the summaries and write them out to the summaries_dir
                  var merged = tf.summary.merge_all();
                  var train_writer = tf.summary.FileWriter(summaries_dir + "/train", sess.graph);
                  var validation_writer = tf.summary.FileWriter(summaries_dir + "/validation", sess.graph);

                  // Create a train saver that is used to restore values into an eval graph
                  // when exporting models.
                  var train_saver = tf.train.Saver();

                  for (int i = 0; i < how_many_training_steps; i++)
                  {
                      var (train_bottlenecks, train_ground_truth, _) = get_random_cached_bottlenecks(
                          sess, image_lists, train_batch_size, "training",
                          bottleneck_dir, image_dir, jpeg_data_tensor,
                          decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
                          tfhub_module);

                      // Feed the bottlenecks and ground truth into the graph, and run a training
                      // step. Capture training summaries for TensorBoard with the `merged` op.
                      var results = sess.run(
                          new ITensorOrOperation[] { merged, train_step },
                          new FeedItem(bottleneck_input, train_bottlenecks),
                          new FeedItem(ground_truth_input, train_ground_truth));
                      var train_summary = results[0];

                      // TODO
                      train_writer.add_summary(train_summary, i);

                      // Every so often, print out how well the graph is training.
                      bool is_last_step = (i + 1 == how_many_training_steps);
                      if ((i % eval_step_interval) == 0 || is_last_step)
                      {
                          results = sess.run(
                              new Tensor[] { evaluation_step, cross_entropy },
                              new FeedItem(bottleneck_input, train_bottlenecks),
                              new FeedItem(ground_truth_input, train_ground_truth));
                          (float train_accuracy, float cross_entropy_value) = (results[0], results[1]);
                          print($"{DateTime.Now}: Step {i}: Train accuracy = {train_accuracy * 100}%");
                          print($"{DateTime.Now}: Step {i}: Cross entropy = {cross_entropy_value}");

                          var (validation_bottlenecks, validation_ground_truth, _) = get_random_cached_bottlenecks(
                              sess, image_lists, validation_batch_size, "validation",
                              bottleneck_dir, image_dir, jpeg_data_tensor,
                              decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
                              tfhub_module);

                          // Run a validation step and capture training summaries for TensorBoard
                          // with the `merged` op.
                          results = sess.run(new Tensor[] { merged, evaluation_step },
                              new FeedItem(bottleneck_input, validation_bottlenecks),
                              new FeedItem(ground_truth_input, validation_ground_truth));
                          (string validation_summary, float validation_accuracy) = (results[0], results[1]);
                          validation_writer.add_summary(validation_summary, i);
                          print($"{DateTime.Now}: Step {i}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)})");
                      }

                      // Store intermediate results
                      int intermediate_frequency = intermediate_store_frequency;
                      if (intermediate_frequency > 0 && i % intermediate_frequency == 0 && i > 0)
                      {
                          // TODO: intermediate graph export is not implemented yet.
                      }
                  }

                  // After training is complete, force one last save of the train checkpoint.
                  train_saver.save(sess, CHECKPOINT_NAME);
              });

              // NOTE(review): no success criterion (e.g. a validation-accuracy threshold) is
              // evaluated yet, so this example always reports false.
              return false;
          }

          /// <summary>
          /// Samples <paramref name="how_many"/> random (label, image) pairs from <paramref name="category"/>,
          /// loading (or creating) the cached bottleneck for each one.
          /// </summary>
          /// <returns>
          /// The bottleneck vectors, the class index of each sample (its label's position in
          /// <paramref name="image_lists"/>' key order), and the image path of each sample.
          /// </returns>
          private (NDArray, long[], string[]) get_random_cached_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
              int how_many, string category, string bottleneck_dir, string image_dir,
              Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor,
              Tensor bottleneck_tensor, string module_name)
          {
              var bottlenecks = new List<float[]>();
              var ground_truths = new List<long>();
              var filenames = new List<string>();

              // Snapshot the labels once instead of re-materializing Keys on every draw.
              // NOTE(review): class ids rely on Dictionary enumerating in insertion order, which
              // holds when no entries are removed but is not a documented guarantee.
              var label_names = image_lists.Keys.ToArray();
              int class_count = label_names.Length;

              foreach (var unused_i in range(how_many))
              {
                  int label_index = rng.Next(class_count);
                  string label_name = label_names[label_index];
                  // get_image_path wraps this index modulo the category size, so any
                  // non-negative value selects a valid image.
                  int image_index = rng.Next(MAX_NUM_IMAGES_PER_CLASS);
                  string image_name = get_image_path(image_lists, label_name, image_index,
                      image_dir, category);
                  var bottleneck = get_or_create_bottleneck(
                      sess, image_lists, label_name, image_index, image_dir, category,
                      bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
                      resized_input_tensor, bottleneck_tensor, module_name);
                  bottlenecks.Add(bottleneck);
                  ground_truths.Add(label_index);
                  filenames.Add(image_name);
              }
              return (bottlenecks.ToArray(), ground_truths.ToArray(), filenames.ToArray());
          }

          /// <summary>
          /// Inserts the operations we need to evaluate the accuracy of our results.
          /// </summary>
          /// <param name="result_tensor">Per-class scores; the arg-max over axis 1 is the prediction.</param>
          /// <param name="ground_truth_tensor">Expected class indices, compared element-wise with the prediction.</param>
          /// <returns>The mean-accuracy tensor and the prediction tensor.</returns>
          private (Tensor, Tensor) add_evaluation_step(Tensor result_tensor, Tensor ground_truth_tensor)
          {
              Tensor evaluation_step = null, correct_prediction = null, prediction = null;

              with(tf.name_scope("accuracy"), scope =>
              {
                  with(tf.name_scope("correct_prediction"), delegate
                  {
                      prediction = tf.argmax(result_tensor, 1);
                      correct_prediction = tf.equal(prediction, ground_truth_tensor);
                  });

                  with(tf.name_scope("accuracy"), delegate
                  {
                      evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32));
                  });
              });

              tf.summary.scalar("accuracy", evaluation_step);
              return (evaluation_step, prediction);
          }

          /// <summary>
          /// Ensures all the training, testing, and validation bottlenecks are cached.
          /// </summary>
          /// <param name="sess">Current active TensorFlow Session.</param>
          /// <param name="image_lists">label -> category -> image paths.</param>
          /// <param name="image_dir">Root folder of the class sub-folders.</param>
          /// <param name="bottleneck_dir">Root folder of the bottleneck cache.</param>
          /// <param name="jpeg_data_tensor">Placeholder fed with raw JPEG bytes.</param>
          /// <param name="decoded_image_tensor">Output of the JPEG decoding/resizing subgraph.</param>
          /// <param name="resized_input_tensor">The input node of the recognition graph.</param>
          /// <param name="bottleneck_tensor">Layer before the final softmax.</param>
          /// <param name="module_name">Module identifier, encoded into the cache-file names.</param>
          private void cache_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
              string image_dir, string bottleneck_dir, Tensor jpeg_data_tensor, Tensor decoded_image_tensor,
              Tensor resized_input_tensor, Tensor bottleneck_tensor, string module_name)
          {
              int how_many_bottlenecks = 0;
              foreach (var (label_name, label_lists) in image_lists)
              {
                  foreach (var category in new string[] { "training", "testing", "validation" })
                  {
                      var category_list = label_lists[category];
                      foreach (var (index, unused_base_name) in enumerate(category_list))
                      {
                          get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir, category,
                              bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
                              resized_input_tensor, bottleneck_tensor, module_name);
                          how_many_bottlenecks++;
                          if (how_many_bottlenecks % 100 == 0)
                              print($"{how_many_bottlenecks} bottleneck files created.");
                      }
                  }
              }
          }

          /// <summary>
          /// Returns the bottleneck vector for one image. The vector is computed and cached on disk
          /// first when no cache file exists, and is recomputed once when the cache file cannot be parsed.
          /// </summary>
          /// <exception cref="InvalidDataException">The freshly recreated cache file still fails to parse.</exception>
          private float[] get_or_create_bottleneck(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
              string label_name, int index, string image_dir, string category, string bottleneck_dir,
              Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor,
              Tensor bottleneck_tensor, string module_name)
          {
              string bottleneck_path = get_bottleneck_path(image_lists, label_name, index,
                  bottleneck_dir, category, module_name);

              // Create the exact folder the cache file lives in. Its name comes from the image's
              // class folder, which may differ in case from the lower-cased label.
              Directory.CreateDirectory(Path.GetDirectoryName(bottleneck_path));

              if (!File.Exists(bottleneck_path))
                  create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
                      image_dir, category, sess, jpeg_data_tensor,
                      decoded_image_tensor, resized_input_tensor,
                      bottleneck_tensor);

              if (!try_read_bottleneck(bottleneck_path, out var bottleneck_values))
              {
                  // For example, a file truncated by an interrupted run, or one written with a
                  // culture-specific number format: regenerate it once.
                  print("Invalid float found, recreating bottleneck");
                  create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
                      image_dir, category, sess, jpeg_data_tensor,
                      decoded_image_tensor, resized_input_tensor,
                      bottleneck_tensor);
                  if (!try_read_bottleneck(bottleneck_path, out bottleneck_values))
                      throw new InvalidDataException($"Could not parse bottleneck file {bottleneck_path}");
              }
              return bottleneck_values;
          }

          /// <summary>
          /// Parses a comma-separated bottleneck cache file written by <c>create_bottleneck_file</c>.
          /// </summary>
          /// <returns>False (and <paramref name="values"/> = null) if any entry is not an invariant-culture float.</returns>
          private static bool try_read_bottleneck(string bottleneck_path, out float[] values)
          {
              var parts = File.ReadAllText(bottleneck_path).Split(',');
              values = new float[parts.Length];
              for (int i = 0; i < parts.Length; i++)
              {
                  if (!float.TryParse(parts[i], System.Globalization.NumberStyles.Float,
                          System.Globalization.CultureInfo.InvariantCulture, out values[i]))
                  {
                      values = null;
                      return false;
                  }
              }
              return true;
          }

          /// <summary>
          /// Runs one image through the graph and writes its bottleneck vector to <paramref name="bottleneck_path"/>
          /// as comma-separated, invariant-culture, round-trippable floats.
          /// </summary>
          /// <exception cref="FileNotFoundException">The source image does not exist.</exception>
          private void create_bottleneck_file(string bottleneck_path, Dictionary<string, Dictionary<string, string[]>> image_lists,
              string label_name, int index, string image_dir, string category, Session sess,
              Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor, Tensor bottleneck_tensor)
          {
              // Create a single bottleneck file.
              print("Creating bottleneck at " + bottleneck_path);
              var image_path = get_image_path(image_lists, label_name, index, image_dir, category);
              if (!File.Exists(image_path))
                  throw new FileNotFoundException($"File does not exist {image_path}", image_path);

              var image_data = File.ReadAllBytes(image_path);
              var bottleneck_values = run_bottleneck_on_image(
                  sess, image_data, jpeg_data_tensor, decoded_image_tensor,
                  resized_input_tensor, bottleneck_tensor);
              var values = bottleneck_values.Data<float>();

              // The invariant culture keeps ',' unambiguous as the separator on every locale, and
              // "R" keeps the values lossless when they are read back.
              var bottleneck_string = string.Join(",",
                  values.Select(v => v.ToString("R", System.Globalization.CultureInfo.InvariantCulture)));
              File.WriteAllText(bottleneck_path, bottleneck_string);
          }

          /// <summary>
          /// Runs inference on an image to extract the 'bottleneck' summary layer.
          /// </summary>
          /// <param name="sess">Current active TensorFlow Session.</param>
          /// <param name="image_data">Data of raw JPEG data.</param>
          /// <param name="image_data_tensor">Input data layer in the graph.</param>
          /// <param name="decoded_image_tensor">Output of initial image resizing and preprocessing.</param>
          /// <param name="resized_input_tensor">The input node of the recognition graph.</param>
          /// <param name="bottleneck_tensor">Layer before the final softmax.</param>
          /// <returns>The bottleneck values with size-1 dimensions squeezed out.</returns>
          private NDArray run_bottleneck_on_image(Session sess, byte[] image_data, Tensor image_data_tensor,
              Tensor decoded_image_tensor, Tensor resized_input_tensor, Tensor bottleneck_tensor)
          {
              // First decode the JPEG image, resize it, and rescale the pixel values.
              var resized_input_values = sess.run(decoded_image_tensor, new FeedItem(image_data_tensor, image_data));
              // Then run it through the recognition network.
              var bottleneck_values = sess.run(bottleneck_tensor, new FeedItem(resized_input_tensor, resized_input_values));
              bottleneck_values = np.squeeze(bottleneck_values);
              return bottleneck_values;
          }

          /// <summary>
          /// Builds the cache-file path for an image. This is the image's path re-rooted under
          /// <paramref name="bottleneck_dir"/>, suffixed with a filesystem-safe form of
          /// <paramref name="module_name"/> and ".txt".
          /// </summary>
          private string get_bottleneck_path(Dictionary<string, Dictionary<string, string[]>> image_lists, string label_name, int index,
              string bottleneck_dir, string category, string module_name)
          {
              module_name = (module_name.Replace("://", "~") // URL scheme.
                  .Replace('/', '~') // URL and Unix paths.
                  .Replace(':', '~').Replace('\\', '~')); // Windows paths.
              return get_image_path(image_lists, label_name, index, bottleneck_dir,
                  category) + "_" + module_name + ".txt";
          }

          /// <summary>
          /// Returns the path of the image at <paramref name="index"/> (wrapped modulo the category size)
          /// for a label and category, re-rooted under <paramref name="image_dir"/>.
          /// </summary>
          /// <exception cref="KeyNotFoundException">The label or category is unknown.</exception>
          /// <exception cref="InvalidOperationException">The category has no images.</exception>
          private string get_image_path(Dictionary<string, Dictionary<string, string[]>> image_lists, string label_name,
              int index, string image_dir, string category)
          {
              if (!image_lists.TryGetValue(label_name, out var label_lists))
                  throw new KeyNotFoundException($"Label does not exist {label_name}");
              if (!label_lists.TryGetValue(category, out var category_list))
                  throw new KeyNotFoundException($"Category does not exist {category}");
              if (category_list.Length == 0)
                  throw new InvalidOperationException($"Label {label_name} has no images in the category {category}.");

              // Non-negative modulo, so a negative index still maps to a valid entry.
              var mod_index = ((index % category_list.Length) + category_list.Length) % category_list.Length;
              var image_file = category_list[mod_index];
              var base_name = Path.GetFileName(image_file);
              // Entries are the paths returned by Directory.GetFiles, so their parent folder is the
              // class folder with its original casing. label_name is lower-cased and would miss
              // mixed-case folders on case-sensitive filesystems.
              var sub_dir = Path.GetFileName(Path.GetDirectoryName(image_file));
              return Path.Join(image_dir, sub_dir, base_name);
          }

          /// <summary>
          /// Downloads and extracts the flower photos and the graph meta file, creates the output
          /// folders, and builds <see cref="image_lists"/>.
          /// </summary>
          /// <exception cref="InvalidOperationException">Fewer than two class folders with images were found.</exception>
          public void PrepareData()
          {
              // get a set of images to teach the network about the new classes
              string fileName = "flower_photos.tgz";
              string url = $"http://download.tensorflow.org/models/{fileName}";
              Web.Download(url, data_dir, fileName);
              Compress.ExtractTGZ(Path.Join(data_dir, fileName), data_dir);

              // download graph meta data
              url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/InceptionV3.meta";
              Web.Download(url, "graph", "InceptionV3.meta");

              // Prepare necessary directories that can be used during training
              Directory.CreateDirectory(summaries_dir);
              Directory.CreateDirectory(bottleneck_dir);

              // Look at the folder structure, and create lists of all the images.
              image_lists = create_image_lists();
              var class_count = len(image_lists);
              // Training cannot proceed meaningfully without at least two classes; fail here
              // rather than with an obscure indexing error later.
              if (class_count == 0)
                  throw new InvalidOperationException($"No valid folders of images found at {image_dir}");
              if (class_count == 1)
                  throw new InvalidOperationException("Only one valid folder of images found at " +
                      image_dir +
                      " - multiple classes are needed for classification.");
          }

          /// <summary>
          /// Builds a subgraph that decodes raw JPEG bytes into a float image in [0, 1],
          /// adds a leading batch dimension of 1, and bilinearly resizes it to 299x299x3.
          /// </summary>
          /// <returns>The JPEG-bytes placeholder and the resized image tensor.</returns>
          private (Tensor, Tensor) add_jpeg_decoding()
          {
              // height, width, depth
              var input_dim = (299, 299, 3);
              var jpeg_data = tf.placeholder(tf.chars, name: "DecodeJPGInput");
              var decoded_image = tf.image.decode_jpeg(jpeg_data, channels: input_dim.Item3);
              // Convert from full range of uint8 to range [0,1] of float32.
              var decoded_image_as_float = tf.image.convert_image_dtype(decoded_image, tf.float32);
              var decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0);
              var resize_shape = tf.stack(new int[] { input_dim.Item1, input_dim.Item2 });
              var resize_shape_as_int = tf.cast(resize_shape, dtype: tf.int32);
              var resized_image = tf.image.resize_bilinear(decoded_image_4d, resize_shape_as_int);
              return (jpeg_data, resized_image);
          }

          /// <summary>
          /// Builds a list of training images from the file system: one label per folder under
          /// <see cref="image_dir"/> that contains JPEG files. Each folder's files are split in
          /// sorted order into testing, validation and training.
          /// </summary>
          private Dictionary<string, Dictionary<string, string[]>> create_image_lists()
          {
              // Depending on the Walk implementation, image_dir itself may be yielded too. It is
              // never a class folder (flower_photos keeps a LICENSE.txt there), so drop it.
              var root = Path.GetFullPath(image_dir)
                  .TrimEnd(Path.DirectorySeparatorChar, Path.AltDirectorySeparatorChar);
              var sub_dirs = tf.gfile.Walk(image_dir)
                  .Select(x => x.Item1)
                  .Where(x => !string.Equals(
                      Path.GetFullPath(x).TrimEnd(Path.DirectorySeparatorChar, Path.AltDirectorySeparatorChar),
                      root, StringComparison.Ordinal))
                  .OrderBy(x => x, StringComparer.Ordinal)
                  .ToArray();

              var result = new Dictionary<string, Dictionary<string, string[]>>();
              foreach (var sub_dir in sub_dirs)
              {
                  var dir_name = Path.GetFileName(sub_dir);
                  print($"Looking for images in '{dir_name}'");

                  // Only JPEGs can be fed to the decode_jpeg subgraph. Sorting makes the
                  // train/validation/test split reproducible across filesystems.
                  var file_list = Directory.GetFiles(sub_dir)
                      .Where(f => IMAGE_EXTENSIONS.Contains(Path.GetExtension(f), StringComparer.OrdinalIgnoreCase))
                      .OrderBy(f => f, StringComparer.Ordinal)
                      .ToArray();
                  if (file_list.Length == 0)
                  {
                      print("WARNING: No files found");
                      continue;
                  }
                  if (len(file_list) < 20)
                      print($"WARNING: Folder has less than 20 images, which may cause issues.");

                  // Invariant lower-casing keeps labels stable regardless of the current culture.
                  var label_name = dir_name.ToLowerInvariant();
                  int testing_count = (int)Math.Floor(file_list.Length * testing_percentage);
                  int validation_count = (int)Math.Floor(file_list.Length * validation_percentage);
                  result[label_name] = new Dictionary<string, string[]>
                  {
                      ["testing"] = file_list.Take(testing_count).ToArray(),
                      ["validation"] = file_list.Skip(testing_count).Take(validation_count).ToArray(),
                      ["training"] = file_list.Skip(testing_count + validation_count).ToArray(),
                  };
              }
              return result;
          }
      }
  }