You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

RetrainImageClassifier.cs 34 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Drawing;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Text;
using Google.Protobuf;
using NumSharp;
using Tensorflow;
using TensorFlowNET.Examples.Utility;
using static Tensorflow.Python;
using Console = Colorful.Console;
  14. namespace TensorFlowNET.Examples.ImageProcess
  15. {
  16. /// <summary>
  17. /// In this tutorial, we will reuse the feature extraction capabilities from powerful image classifiers trained on ImageNet
  18. /// and simply train a new classification layer on top. Transfer learning is a technique that shortcuts much of this
  19. /// by taking a piece of a model that has already been trained on a related task and reusing it in a new model.
  20. ///
  21. /// https://www.tensorflow.org/hub/tutorials/image_retraining
  22. /// </summary>
  23. public class RetrainImageClassifier : IExample
  24. {
// IExample metadata consumed by the example runner.
public int Priority => 16;
public bool Enabled { get; set; } = true;
public bool IsImportingGraph { get; set; } = true;
public string Name => "Retrain Image Classifier";

// Root folder for downloaded data and every artifact this example produces.
const string data_dir = "retrain_images";
// TensorBoard summaries are written under here ("/train" and "/validation").
string summaries_dir = Path.Join(data_dir, "retrain_logs");
// Extracted flower_photos data set: one sub-folder per class label.
string image_dir = Path.Join(data_dir, "flower_photos");
// On-disk cache of bottleneck (feature-vector) text files, one per image.
string bottleneck_dir = Path.Join(data_dir, "bottleneck");
// Final frozen graph and the label list written next to it.
string output_graph = Path.Join(data_dir, "output_graph.pb");
string output_labels = Path.Join(data_dir, "output_labels.txt");
// The location where variable checkpoints will be stored.
string CHECKPOINT_NAME = Path.Join(data_dir, "_retrain_checkpoint");
// TF-Hub module the feature extractor came from; used here for cache file naming.
string tfhub_module = "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3";
// Name given to the new softmax output node added on top of the module.
string final_tensor_name = "final_result";
// Fraction of each label's images reserved for testing / validation.
float testing_percentage = 0.1f;
float validation_percentage = 0.1f;
float learning_rate = 0.01f;
Tensor resized_image_tensor;
// label name -> { "training" | "testing" | "validation" -> image file paths }
Dictionary<string, Dictionary<string, string[]>> image_lists;
int how_many_training_steps = 100;
// How often (in steps) training/validation accuracy is printed.
int eval_step_interval = 10;
int train_batch_size = 100;
// -1 means "evaluate on the entire testing set".
int test_batch_size = -1;
int validation_batch_size = 100;
// 0 disables intermediate checkpoint storing (see the loop in Run()).
int intermediate_store_frequency = 0;
int class_count = 0;
// 2^27 - 1; upper bound for the random per-class image index (wrapped by modulo later).
const int MAX_NUM_IMAGES_PER_CLASS = 134217727;
// Graph nodes created by add_final_retrain_ops and shared with Run().
Operation train_step;
Tensor final_tensor;
Tensor bottleneck_input;
Tensor cross_entropy;
Tensor ground_truth_input;
/// <summary>
/// Entry point: downloads the data, builds the retraining graph, trains the new
/// final layer on cached bottlenecks, then runs a final test evaluation and
/// exports the frozen graph plus labels. Returns true when the final test
/// accuracy exceeds 75%.
/// </summary>
public bool Run()
{
    PrepareData();

    // Set up the pre-trained graph.
    // (this local resized_image_tensor shadows the class field of the same name)
    var (graph, bottleneck_tensor, resized_image_tensor, wants_quantization) =
        create_module_graph();

    // Add the new layer that we'll be training.
    with(graph.as_default(), delegate
    {
        (train_step, cross_entropy, bottleneck_input,
            ground_truth_input, final_tensor) = add_final_retrain_ops(
            class_count, final_tensor_name, bottleneck_tensor,
            wants_quantization, is_training: true);
    });

    var sw = new Stopwatch();

    return with(tf.Session(graph), sess =>
    {
        // Initialize all weights: for the module to their pretrained values,
        // and for the newly added retraining layer to random initial values.
        var init = tf.global_variables_initializer();
        sess.run(init);

        var (jpeg_data_tensor, decoded_image_tensor) = add_jpeg_decoding();

        // We'll make sure we've calculated the 'bottleneck' image summaries and
        // cached them on disk.
        cache_bottlenecks(sess, image_lists, image_dir,
            bottleneck_dir, jpeg_data_tensor,
            decoded_image_tensor, resized_image_tensor,
            bottleneck_tensor, tfhub_module);

        // Create the operations we need to evaluate the accuracy of our new layer.
        var (evaluation_step, _) = add_evaluation_step(final_tensor, ground_truth_input);

        // Merge all the summaries and write them out to the summaries_dir
        var merged = tf.summary.merge_all();
        var train_writer = tf.summary.FileWriter(summaries_dir + "/train", sess.graph);
        var validation_writer = tf.summary.FileWriter(summaries_dir + "/validation", sess.graph);

        // Create a train saver that is used to restore values into an eval graph
        // when exporting models.
        var train_saver = tf.train.Saver();
        train_saver.save(sess, CHECKPOINT_NAME);

        sw.Restart();

        for (int i = 0; i < how_many_training_steps; i++)
        {
            // Draw a fresh random training batch from the bottleneck cache.
            var (train_bottlenecks, train_ground_truth, _) = get_random_cached_bottlenecks(
                sess, image_lists, train_batch_size, "training",
                bottleneck_dir, image_dir, jpeg_data_tensor,
                decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
                tfhub_module);

            // Feed the bottlenecks and ground truth into the graph, and run a training
            // step. Capture training summaries for TensorBoard with the `merged` op.
            var results = sess.run(
                new ITensorOrOperation[] { merged, train_step },
                new FeedItem(bottleneck_input, train_bottlenecks),
                new FeedItem(ground_truth_input, train_ground_truth));
            var train_summary = results[0];

            // TODO
            train_writer.add_summary(train_summary, i);

            // Every so often, print out how well the graph is training.
            bool is_last_step = (i + 1 == how_many_training_steps);
            if ((i % eval_step_interval) == 0 || is_last_step)
            {
                // Re-run the same batch to fetch accuracy and loss for logging.
                results = sess.run(
                    new Tensor[] { evaluation_step, cross_entropy },
                    new FeedItem(bottleneck_input, train_bottlenecks),
                    new FeedItem(ground_truth_input, train_ground_truth));
                (float train_accuracy, float cross_entropy_value) = (results[0], results[1]);
                print($"{DateTime.Now}: Step {i + 1}: Train accuracy = {train_accuracy * 100}%, Cross entropy = {cross_entropy_value.ToString("G4")}");

                var (validation_bottlenecks, validation_ground_truth, _) = get_random_cached_bottlenecks(
                    sess, image_lists, validation_batch_size, "validation",
                    bottleneck_dir, image_dir, jpeg_data_tensor,
                    decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
                    tfhub_module);

                // Run a validation step and capture training summaries for TensorBoard
                // with the `merged` op.
                results = sess.run(new Tensor[] { merged, evaluation_step },
                    new FeedItem(bottleneck_input, validation_bottlenecks),
                    new FeedItem(ground_truth_input, validation_ground_truth));

                (string validation_summary, float validation_accuracy) = (results[0], results[1]);

                validation_writer.add_summary(validation_summary, i);
                print($"{DateTime.Now}: Step {i + 1}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)}) {sw.ElapsedMilliseconds}ms");
                sw.Restart();
            }

            // Store intermediate results
            // (intermediate_store_frequency is 0 by default, so this branch is a no-op stub.)
            int intermediate_frequency = intermediate_store_frequency;

            if (intermediate_frequency > 0 && i % intermediate_frequency == 0 && i > 0)
            {

            }
        }

        // After training is complete, force one last save of the train checkpoint.
        train_saver.save(sess, CHECKPOINT_NAME);

        // We've completed all our training, so run a final test evaluation on
        // some new images we haven't used before.
        var (test_accuracy, predictions) = run_final_eval(sess, null, class_count, image_lists,
                                jpeg_data_tensor, decoded_image_tensor, resized_image_tensor,
                                bottleneck_tensor);

        // Write out the trained graph and labels with the weights stored as
        // constants.
        print($"Save final result to : {output_graph}");
        save_graph_to_file(output_graph, class_count);
        File.WriteAllText(output_labels, string.Join("\n", image_lists.Keys));

        return test_accuracy > 0.75f;
    });
}
/// <summary>
/// Runs a final evaluation on an eval graph using the test data set.
/// Builds a fresh eval session (restored from the training checkpoint) and
/// prints/returns the test accuracy together with the per-image predictions.
/// </summary>
/// <param name="train_session">Session holding the trained weights; used to compute any missing test bottlenecks.</param>
/// <param name="module_spec">Unused here (kept for parity with the original Python signature).</param>
/// <param name="class_count">Number of output classes.</param>
/// <param name="image_lists">label -> category -> image file paths.</param>
/// <param name="jpeg_data_tensor">Input placeholder for raw JPEG bytes.</param>
/// <param name="decoded_image_tensor">Output of the JPEG decode/resize pipeline.</param>
/// <param name="resized_image_tensor">Image input node of the recognition graph.</param>
/// <param name="bottleneck_tensor">Feature-vector output of the pre-trained module.</param>
private (float, NDArray) run_final_eval(Session train_session, object module_spec, int class_count,
    Dictionary<string, Dictionary<string, string[]>> image_lists,
    Tensor jpeg_data_tensor, Tensor decoded_image_tensor,
    Tensor resized_image_tensor, Tensor bottleneck_tensor)
{
    // test_batch_size is -1 by default, i.e. use every image in the testing set.
    var (test_bottlenecks, test_ground_truth, test_filenames) = get_random_cached_bottlenecks(train_session, image_lists,
        test_batch_size, "testing", bottleneck_dir, image_dir, jpeg_data_tensor,
        decoded_image_tensor, resized_image_tensor, bottleneck_tensor, tfhub_module);

    // These locals intentionally shadow the class fields of the same names:
    // they belong to the separate eval graph, not the training graph.
    var (eval_session, _, bottleneck_input, ground_truth_input, evaluation_step,
        prediction) = build_eval_session(class_count);

    var results = eval_session.run(new Tensor[] { evaluation_step, prediction },
        new FeedItem(bottleneck_input, test_bottlenecks),
        new FeedItem(ground_truth_input, test_ground_truth));

    print($"final test accuracy: {((float)results[0] * 100).ToString("G4")}% (N={len(test_bottlenecks)})");

    // (accuracy, argmax predictions)
    return (results[0], results[1]);
}
/// <summary>
/// Builds a second, inference-only copy of the graph and restores the trained
/// final-layer weights into it from CHECKPOINT_NAME. Returned tuple:
/// (session, resized image input, bottleneck input, ground-truth input,
/// accuracy op, prediction op).
/// </summary>
private (Session, Tensor, Tensor, Tensor, Tensor, Tensor)
    build_eval_session(int class_count)
{
    // If quantized, we need to create the correct eval graph for exporting.
    var (eval_graph, bottleneck_tensor, resized_input_tensor, wants_quantization) = create_module_graph();
    var eval_sess = tf.Session(graph: eval_graph);
    Tensor evaluation_step = null;
    Tensor prediction = null;
    with(eval_graph.as_default(), graph =>
    {
        // Add the new layer for exporting. is_training: false skips the loss
        // and optimizer ops, so the first two tuple elements are null.
        var (_, _, bottleneck_input, ground_truth_input, final_tensor) =
            add_final_retrain_ops(class_count, final_tensor_name, bottleneck_tensor,
                wants_quantization, is_training: false);

        // Now we need to restore the values from the training graph to the eval
        // graph.
        tf.train.Saver().restore(eval_sess, CHECKPOINT_NAME);

        (evaluation_step, prediction) = add_evaluation_step(final_tensor,
            ground_truth_input);
    });

    // NOTE(review): bottleneck_input / ground_truth_input resolved here are the
    // class fields (assigned inside add_final_retrain_ops), since the tuple
    // locals above are scoped to the lambda — the values match, but verify.
    return (eval_sess, resized_input_tensor, bottleneck_input, ground_truth_input,
        evaluation_step, prediction);
}
/// <summary>
/// Adds a new softmax and fully-connected layer for training and eval.
///
/// We need to retrain the top layer to identify our new classes, so this function
/// adds the right operations to the graph, along with some variables to hold the
/// weights, and then sets up all the gradients for the backward pass.
///
/// The set up for the softmax and fully-connected layers is based on:
/// https://www.tensorflow.org/tutorials/mnist/beginners/index.html
///
/// Side effects: assigns the class fields bottleneck_input, ground_truth_input,
/// final_tensor and (when training) train_step.
/// </summary>
/// <param name="class_count">Number of output classes.</param>
/// <param name="final_tensor_name">Name for the softmax output node.</param>
/// <param name="bottleneck_tensor">Feature-vector output of the pre-trained module.</param>
/// <param name="quantize_layer">True to rewrite the new layer for quantization (not implemented).</param>
/// <param name="is_training">False builds an eval-only layer (no loss/optimizer; first two return values are null).</param>
/// <returns>(train op, cross-entropy, bottleneck input, ground-truth input, softmax output)</returns>
private (Operation, Tensor, Tensor, Tensor, Tensor) add_final_retrain_ops(int class_count, string final_tensor_name,
    Tensor bottleneck_tensor, bool quantize_layer, bool is_training)
{
    // Shape of the module output: (batch, feature size).
    var (batch_size, bottleneck_tensor_size) = (bottleneck_tensor.GetShape().Dimensions[0], bottleneck_tensor.GetShape().Dimensions[1]);
    with(tf.name_scope("input"), scope =>
    {
        // Defaults to the live bottleneck tensor, but Run() feeds cached values here.
        bottleneck_input = tf.placeholder_with_default(
            bottleneck_tensor,
            shape: bottleneck_tensor.GetShape().Dimensions,
            name: "BottleneckInputPlaceholder");

        // Sparse integer labels, one per batch row.
        ground_truth_input = tf.placeholder(tf.int64, new TensorShape(batch_size), name: "GroundTruthInput");
    });

    // Organizing the following ops so they are easier to see in TensorBoard.
    string layer_name = "final_retrain_ops";
    Tensor logits = null;
    with(tf.name_scope(layer_name), scope =>
    {
        RefVariable layer_weights = null;
        with(tf.name_scope("weights"), delegate
        {
            // Small-stddev truncated normal init for the new FC layer.
            var initial_value = tf.truncated_normal(new int[] { bottleneck_tensor_size, class_count }, stddev: 0.001f);
            layer_weights = tf.Variable(initial_value, name: "final_weights");
            variable_summaries(layer_weights);
        });

        RefVariable layer_biases = null;
        with(tf.name_scope("biases"), delegate
        {
            layer_biases = tf.Variable(tf.zeros((class_count)), name: "final_biases");
            variable_summaries(layer_biases);
        });

        with(tf.name_scope("Wx_plus_b"), delegate
        {
            logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases;
            tf.summary.histogram("pre_activations", logits);
        });
    });

    final_tensor = tf.nn.softmax(logits, name: final_tensor_name);

    // The tf.contrib.quantize functions rewrite the graph in place for
    // quantization. The imported model graph has already been rewritten, so upon
    // calling these rewrites, only the newly added final layer will be
    // transformed.
    if (quantize_layer)
    {
        throw new NotImplementedException("quantize_layer");
        /*if (is_training)
            tf.contrib.quantize.create_training_graph();
        else
            tf.contrib.quantize.create_eval_graph();*/
    }

    tf.summary.histogram("activations", final_tensor);

    // If this is an eval graph, we don't need to add loss ops or an optimizer.
    if (!is_training)
        return (null, null, bottleneck_input, ground_truth_input, final_tensor);

    Tensor cross_entropy_mean = null;
    with(tf.name_scope("cross_entropy"), delegate
    {
        cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
            labels: ground_truth_input, logits: logits);
    });

    tf.summary.scalar("cross_entropy", cross_entropy_mean);

    with(tf.name_scope("train"), delegate
    {
        var optimizer = tf.train.GradientDescentOptimizer(learning_rate);
        train_step = optimizer.minimize(cross_entropy_mean);
    });

    return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input,
        final_tensor);
}
/// <summary>
/// Attaches mean/stddev/min/max scalar summaries and a histogram to a variable
/// so it can be inspected in TensorBoard.
/// </summary>
/// <param name="var">The variable to summarize.</param>
private void variable_summaries(RefVariable var)
{
    with(tf.name_scope("summaries"), delegate
    {
        var mean = tf.reduce_mean(var);
        tf.summary.scalar("mean", mean);
        Tensor stddev = null;
        with(tf.name_scope("stddev"), delegate {
            // stddev = sqrt(mean((var - mean)^2))
            stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)));
        });
        tf.summary.scalar("stddev", stddev);
        tf.summary.scalar("max", tf.reduce_max(var));
        tf.summary.scalar("min", tf.reduce_min(var));
        tf.summary.histogram("histogram", var);
    });
}
/// <summary>
/// Creates a new graph, imports the pre-trained Inception V3 meta graph into it
/// and looks up its input and bottleneck tensors by hard-coded operation names.
/// Returns (graph, bottleneck tensor, resized image input, wants_quantization).
/// </summary>
private (Graph, Tensor, Tensor, bool) create_module_graph()
{
    // Expected input image size for Inception V3; not otherwise used here.
    var (height, width) = (299, 299);

    return with(tf.Graph().as_default(), graph =>
    {
        // Instead of instantiating a TF-Hub module, the pre-exported meta graph
        // is imported and its nodes are addressed by name.
        tf.train.import_meta_graph("graph/InceptionV3.meta");
        // NOTE(review): these op names must match the exported meta graph
        // exactly — verify against graph/InceptionV3.meta if the download changes.
        Tensor resized_input_tensor = graph.OperationByName("Placeholder"); //tf.placeholder(tf.float32, new TensorShape(-1, height, width, 3));
        // var m = hub.Module(module_spec);
        Tensor bottleneck_tensor = graph.OperationByName("module_apply_default/hub_output/feature_vector/SpatialSqueeze");// m(resized_input_tensor);
        var wants_quantization = false;
        return (graph, bottleneck_tensor, resized_input_tensor, wants_quantization);
    });
}
  321. private (NDArray, long[], string[]) get_random_cached_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
  322. int how_many, string category, string bottleneck_dir, string image_dir,
  323. Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor,
  324. Tensor bottleneck_tensor, string module_name)
  325. {
  326. var bottlenecks = new List<float[]>();
  327. var ground_truths = new List<long>();
  328. var filenames = new List<string>();
  329. class_count = image_lists.Keys.Count;
  330. if (how_many >= 0)
  331. {
  332. // Retrieve a random sample of bottlenecks.
  333. foreach (var unused_i in range(how_many))
  334. {
  335. int label_index = new Random().Next(class_count);
  336. string label_name = image_lists.Keys.ToArray()[label_index];
  337. int image_index = new Random().Next(MAX_NUM_IMAGES_PER_CLASS);
  338. string image_name = get_image_path(image_lists, label_name, image_index,
  339. image_dir, category);
  340. var bottleneck = get_or_create_bottleneck(
  341. sess, image_lists, label_name, image_index, image_dir, category,
  342. bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
  343. resized_input_tensor, bottleneck_tensor, module_name);
  344. bottlenecks.Add(bottleneck);
  345. ground_truths.Add(label_index);
  346. filenames.Add(image_name);
  347. }
  348. }
  349. else
  350. {
  351. // Retrieve all bottlenecks.
  352. foreach (var (label_index, label_name) in enumerate(image_lists.Keys.ToArray()))
  353. {
  354. foreach(var (image_index, image_name) in enumerate(image_lists[label_name][category]))
  355. {
  356. var bottleneck = get_or_create_bottleneck(
  357. sess, image_lists, label_name, image_index, image_dir, category,
  358. bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
  359. resized_input_tensor, bottleneck_tensor, module_name);
  360. bottlenecks.Add(bottleneck);
  361. ground_truths.Add(label_index);
  362. filenames.Add(image_name);
  363. }
  364. }
  365. }
  366. return (bottlenecks.ToArray(), ground_truths.ToArray(), filenames.ToArray());
  367. }
/// <summary>
/// Inserts the operations we need to evaluate the accuracy of our results.
/// Accuracy is the mean of (argmax(result) == ground truth) over the batch.
/// </summary>
/// <param name="result_tensor">Softmax output of the new layer.</param>
/// <param name="ground_truth_tensor">Integer class labels.</param>
/// <returns>(accuracy tensor, argmax prediction tensor)</returns>
private (Tensor, Tensor) add_evaluation_step(Tensor result_tensor, Tensor ground_truth_tensor)
{
    Tensor evaluation_step = null, correct_prediction = null, prediction = null;

    with(tf.name_scope("accuracy"), scope =>
    {
        with(tf.name_scope("correct_prediction"), delegate
        {
            // Predicted class = index of the largest softmax probability.
            prediction = tf.argmax(result_tensor, 1);
            correct_prediction = tf.equal(prediction, ground_truth_tensor);
        });

        with(tf.name_scope("accuracy"), delegate
        {
            // Cast the bool matches to float and average them.
            evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32));
        });
    });

    tf.summary.scalar("accuracy", evaluation_step);
    return (evaluation_step, prediction);
}
/// <summary>
/// Ensures all the training, testing, and validation bottlenecks are cached.
/// Iterates every image of every label/category and computes any bottleneck
/// file that does not exist yet, logging progress every 300 files.
/// </summary>
/// <param name="sess">Active session used to run the feature extractor.</param>
/// <param name="image_lists">label -> category -> image file paths.</param>
/// <param name="image_dir">Root folder of the source images.</param>
/// <param name="bottleneck_dir">Root folder of the bottleneck cache.</param>
/// <param name="jpeg_data_tensor">Input placeholder for raw JPEG bytes.</param>
/// <param name="decoded_image_tensor">Output of the JPEG decode/resize pipeline.</param>
/// <param name="resized_input_tensor">Image input node of the recognition graph.</param>
/// <param name="bottleneck_tensor">Feature-vector output of the module.</param>
/// <param name="module_name">TF-Hub module name, used for cache file naming.</param>
private void cache_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
    string image_dir, string bottleneck_dir, Tensor jpeg_data_tensor, Tensor decoded_image_tensor,
    Tensor resized_input_tensor, Tensor bottleneck_tensor, string module_name)
{
    int how_many_bottlenecks = 0;
    foreach (var (label_name, label_lists) in image_lists)
    {
        foreach (var category in new string[] { "training", "testing", "validation" })
        {
            var category_list = label_lists[category];
            foreach (var (index, unused_base_name) in enumerate(category_list))
            {
                // Computes and writes the bottleneck file if it is missing.
                get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir, category,
                    bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
                    resized_input_tensor, bottleneck_tensor, module_name);
                how_many_bottlenecks++;
                if (how_many_bottlenecks % 300 == 0)
                    print($"{how_many_bottlenecks} bottleneck files created.");
            }
        }
    }
}
  426. private float[] get_or_create_bottleneck(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
  427. string label_name, int index, string image_dir, string category, string bottleneck_dir,
  428. Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor,
  429. Tensor bottleneck_tensor, string module_name)
  430. {
  431. var label_lists = image_lists[label_name];
  432. var sub_dir_path = Path.Join(bottleneck_dir, label_name);
  433. Directory.CreateDirectory(sub_dir_path);
  434. string bottleneck_path = get_bottleneck_path(image_lists, label_name, index,
  435. bottleneck_dir, category, module_name);
  436. if (!File.Exists(bottleneck_path))
  437. create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
  438. image_dir, category, sess, jpeg_data_tensor,
  439. decoded_image_tensor, resized_input_tensor,
  440. bottleneck_tensor);
  441. var bottleneck_string = File.ReadAllText(bottleneck_path);
  442. var bottleneck_values = Array.ConvertAll(bottleneck_string.Split(','), x => float.Parse(x));
  443. return bottleneck_values;
  444. }
  445. private void create_bottleneck_file(string bottleneck_path, Dictionary<string, Dictionary<string, string[]>> image_lists,
  446. string label_name, int index, string image_dir, string category, Session sess,
  447. Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor, Tensor bottleneck_tensor)
  448. {
  449. // Create a single bottleneck file.
  450. print("Creating bottleneck at " + bottleneck_path);
  451. var image_path = get_image_path(image_lists, label_name, index, image_dir, category);
  452. if (!File.Exists(image_path))
  453. print($"File does not exist {image_path}");
  454. var image_data = File.ReadAllBytes(image_path);
  455. var bottleneck_values = run_bottleneck_on_image(
  456. sess, image_data, jpeg_data_tensor, decoded_image_tensor,
  457. resized_input_tensor, bottleneck_tensor);
  458. var values = bottleneck_values.Data<float>();
  459. var bottleneck_string = string.Join(",", values);
  460. File.WriteAllText(bottleneck_path, bottleneck_string);
  461. }
/// <summary>
/// Runs inference on an image to extract the 'bottleneck' summary layer.
/// </summary>
/// <param name="sess">Current active TensorFlow Session.</param>
/// <param name="image_data">Data of raw JPEG data.</param>
/// <param name="image_data_tensor">Input data layer in the graph.</param>
/// <param name="decoded_image_tensor">Output of initial image resizing and preprocessing.</param>
/// <param name="resized_input_tensor">The input node of the recognition graph.</param>
/// <param name="bottleneck_tensor">Layer before the final softmax.</param>
/// <returns>The squeezed bottleneck feature vector for the image.</returns>
private NDArray run_bottleneck_on_image(Session sess, byte[] image_data, Tensor image_data_tensor,
    Tensor decoded_image_tensor, Tensor resized_input_tensor, Tensor bottleneck_tensor)
{
    // First decode the JPEG image, resize it, and rescale the pixel values.
    var resized_input_values = sess.run(decoded_image_tensor, new FeedItem(image_data_tensor, image_data));
    // Then run it through the recognition network.
    var bottleneck_values = sess.run(bottleneck_tensor, new FeedItem(resized_input_tensor, resized_input_values));
    // Drop the leading batch dimension of size 1.
    bottleneck_values = np.squeeze(bottleneck_values);
    return bottleneck_values;
}
  482. private string get_bottleneck_path(Dictionary<string, Dictionary<string, string[]>> image_lists, string label_name, int index,
  483. string bottleneck_dir, string category, string module_name)
  484. {
  485. module_name = (module_name.Replace("://", "~") // URL scheme.
  486. .Replace('/', '~') // URL and Unix paths.
  487. .Replace(':', '~').Replace('\\', '~')); // Windows paths.
  488. return get_image_path(image_lists, label_name, index, bottleneck_dir,
  489. category) + "_" + module_name + ".txt";
  490. }
  491. private string get_image_path(Dictionary<string, Dictionary<string, string[]>> image_lists, string label_name,
  492. int index, string image_dir, string category)
  493. {
  494. if (!image_lists.ContainsKey(label_name))
  495. print($"Label does not exist {label_name}");
  496. var label_lists = image_lists[label_name];
  497. if (!label_lists.ContainsKey(category))
  498. print($"Category does not exist {category}");
  499. var category_list = label_lists[category];
  500. if (category_list.Length == 0)
  501. print($"Label {label_name} has no images in the category {category}.");
  502. var mod_index = index % len(category_list);
  503. var base_name = category_list[mod_index].Split(Path.DirectorySeparatorChar).Last();
  504. var sub_dir = label_name;
  505. var full_path = Path.Join(image_dir, sub_dir, base_name);
  506. return full_path;
  507. }
/// <summary>
/// Saves an graph to file, creating a valid quantized one if necessary.
/// A fresh eval session is built (restoring weights from the checkpoint) so
/// the trained variables can be folded into constants before serialization.
/// </summary>
/// <param name="graph_file_name">Destination path for the frozen GraphDef.</param>
/// <param name="class_count">Number of output classes.</param>
private void save_graph_to_file(string graph_file_name, int class_count)
{
    var (sess, _, _, _, _, _) = build_eval_session(class_count);
    var graph = sess.graph;
    // Replace variables with constants, keeping only what final_result depends on.
    var output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), new string[] { final_tensor_name });
    File.WriteAllBytes(graph_file_name, output_graph_def.ToByteArray());
}
/// <summary>
/// Downloads and extracts the flower_photos data set, the Inception V3 meta
/// graph and its checkpoint modules, creates the working directories, and
/// builds image_lists / class_count. Warnings (0 or 1 classes) are printed,
/// not thrown.
/// </summary>
public void PrepareData()
{
    // get a set of images to teach the network about the new classes
    string fileName = "flower_photos.tgz";
    string url = $"http://download.tensorflow.org/example_images/{fileName}";
    Web.Download(url, data_dir, fileName);
    Compress.ExtractTGZ(Path.Join(data_dir, fileName), data_dir);

    // download graph meta data
    url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/InceptionV3.meta";
    Web.Download(url, "graph", "InceptionV3.meta");

    // download variables.data checkpoint file.
    url = "https://github.com/SciSharp/TensorFlow.NET/raw/master/data/tfhub_modules.zip";
    Web.Download(url, data_dir, "tfhub_modules.zip");
    Compress.UnZip(Path.Join(data_dir, "tfhub_modules.zip"), "tfhub_modules");

    // Prepare necessary directories that can be used during training
    Directory.CreateDirectory(summaries_dir);
    Directory.CreateDirectory(bottleneck_dir);

    // Look at the folder structure, and create lists of all the images.
    image_lists = create_image_lists();
    class_count = len(image_lists);
    if (class_count == 0)
        print($"No valid folders of images found at {image_dir}");
    if (class_count == 1)
        print("Only one valid folder of images found at " +
             image_dir +
             " - multiple classes are needed for classification.");
}
/// <summary>
/// Adds graph ops that decode a raw JPEG byte string, rescale it to float
/// [0,1], and bilinearly resize it to the 299x299x3 input expected by the
/// Inception V3 module.
/// </summary>
/// <returns>(JPEG input placeholder, resized image tensor)</returns>
private (Tensor, Tensor) add_jpeg_decoding()
{
    // height, width, depth
    var input_dim = (299, 299, 3);
    var jpeg_data = tf.placeholder(tf.chars, name: "DecodeJPGInput");
    var decoded_image = tf.image.decode_jpeg(jpeg_data, channels: input_dim.Item3);
    // Convert from full range of uint8 to range [0,1] of float32.
    var decoded_image_as_float = tf.image.convert_image_dtype(decoded_image, tf.float32);
    // Add a batch dimension of 1 for the resize op.
    var decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0);
    var resize_shape = tf.stack(new int[] { input_dim.Item1, input_dim.Item2 });
    var resize_shape_as_int = tf.cast(resize_shape, dtype: tf.int32);
    var resized_image = tf.image.resize_bilinear(decoded_image_4d, resize_shape_as_int);
    return (jpeg_data, resized_image);
}
/// <summary>
/// Builds a list of training images from the file system.
/// Each label folder's files are split into testing / validation / training
/// partitions according to testing_percentage and validation_percentage.
/// </summary>
/// <returns>label name -> {"testing"|"validation"|"training" -> file paths}</returns>
private Dictionary<string, Dictionary<string, string[]>> create_image_lists()
{
    // NOTE(review): tf.gfile.Walk presumably yields image_dir itself among the
    // entries; if so, the root folder would appear as a (bogus) label — verify
    // against the tf.gfile.Walk implementation.
    var sub_dirs = tf.gfile.Walk(image_dir)
        .Select(x => x.Item1)
        .OrderBy(x => x)
        .ToArray();

    var result = new Dictionary<string, Dictionary<string, string[]>>();

    foreach (var sub_dir in sub_dirs)
    {
        var dir_name = sub_dir.Split(Path.DirectorySeparatorChar).Last();
        print($"Looking for images in '{dir_name}'");
        var file_list = Directory.GetFiles(sub_dir);
        if (len(file_list) < 20)
            print($"WARNING: Folder has less than 20 images, which may cause issues.");

        var label_name = dir_name.ToLower();
        result[label_name] = new Dictionary<string, string[]>();
        // Deterministic split: the first files go to testing, the next to
        // validation, the remainder to training.
        int testing_count = (int)Math.Floor(file_list.Length * testing_percentage);
        int validation_count = (int)Math.Floor(file_list.Length * validation_percentage);
        result[label_name]["testing"] = file_list.Take(testing_count).ToArray();
        result[label_name]["validation"] = file_list.Skip(testing_count).Take(validation_count).ToArray();
        result[label_name]["training"] = file_list.Skip(testing_count + validation_count).ToArray();
    }

    return result;
}
  589. public Graph ImportGraph()
  590. {
  591. throw new NotImplementedException();
  592. }
  593. public Graph BuildGraph()
  594. {
  595. throw new NotImplementedException();
  596. }
  597. public bool Train()
  598. {
  599. throw new NotImplementedException();
  600. }
  601. public bool Predict()
  602. {
  603. throw new NotImplementedException();
  604. }
  605. }
  606. }