You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

RetrainImageClassifier.cs 35 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724
  1. /*****************************************************************************
  2. Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ******************************************************************************/
  13. using Google.Protobuf;
  14. using NumSharp;
  15. using System;
  16. using System.Collections.Generic;
  17. using System.Diagnostics;
  18. using System.Drawing;
  19. using System.IO;
  20. using System.Linq;
  21. using System.Text;
  22. using Tensorflow;
  23. using TensorFlowNET.Examples.Utility;
  24. using static Tensorflow.Python;
  25. using Console = Colorful.Console;
  26. namespace TensorFlowNET.Examples.ImageProcess
  27. {
  28. /// <summary>
  29. /// In this tutorial, we will reuse the feature extraction capabilities from powerful image classifiers trained on ImageNet
  30. /// and simply train a new classification layer on top. Transfer learning is a technique that shortcuts much of this
  31. /// by taking a piece of a model that has already been trained on a related task and reusing it in a new model.
  32. ///
  33. /// https://www.tensorflow.org/hub/tutorials/image_retraining
  34. /// </summary>
  35. public class RetrainImageClassifier : IExample
  36. {
  // ----- IExample contract ------------------------------------------------
  public int Priority => 16;
  public bool Enabled { get; set; } = true;
  public bool IsImportingGraph { get; set; } = true;
  public string Name => "Retrain Image Classifier";

  // ----- Locations on disk (everything lives under data_dir) --------------
  const string data_dir = "retrain_images";
  string image_dir = Path.Join(data_dir, "flower_photos");
  string bottleneck_dir = Path.Join(data_dir, "bottleneck");
  string summaries_dir = Path.Join(data_dir, "retrain_logs");
  string output_graph = Path.Join(data_dir, "output_graph.pb");
  string output_labels = Path.Join(data_dir, "output_labels.txt");
  // Where the training checkpoint is saved; the eval graph restores its weights from here.
  string CHECKPOINT_NAME = Path.Join(data_dir, "_retrain_checkpoint");

  // ----- Module / output configuration ------------------------------------
  // Passed around as `module_name`; it only ends up as a suffix of the bottleneck cache
  // file names (the module graph itself is loaded from graph/InceptionV3.meta).
  string tfhub_module = "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3";
  string final_tensor_name = "final_result";
  bool wants_quantization;

  // ----- Data split and training hyper-parameters -------------------------
  float testing_percentage = 0.1f;
  float validation_percentage = 0.1f;
  float learning_rate = 0.01f;
  int how_many_training_steps = 100;
  int eval_step_interval = 10;
  int train_batch_size = 100;
  // A negative batch size means "use every image of the split".
  int test_batch_size = -1;
  int validation_batch_size = 100;
  // Values <= 0 disable intermediate stores.
  int intermediate_store_frequency = 0;
  // Upper bound (2^27 - 1) for random image indices; indices are wrapped modulo the
  // number of images in the category.
  const int MAX_NUM_IMAGES_PER_CLASS = (1 << 27) - 1;

  // ----- State shared between data preparation, graph building and training
  int class_count = 0;
  Dictionary<string, Dictionary<string, string[]>> image_lists;
  Tensor resized_image_tensor;
  Tensor bottleneck_tensor;
  Tensor bottleneck_input;
  Tensor ground_truth_input;
  Tensor final_tensor;
  Tensor cross_entropy;
  Operation train_step;

  // ----- Results of the final evaluation ----------------------------------
  float test_accuracy;
  NDArray predictions;
  /// <summary>
  /// Example entry point: prepares the data, obtains the training graph (imported or
  /// built), trains the new top layer inside a session and reports success.
  /// </summary>
  /// <returns><c>true</c> when the final test accuracy is greater than 0.75.</returns>
  public bool Run()
  {
      PrepareData();

      Graph graph;
      if (IsImportingGraph)
          graph = ImportGraph();
      else
          graph = BuildGraph();

      with(tf.Session(graph), session => Train(session));

      const float accuracy_threshold = 0.75f;
      return test_accuracy > accuracy_threshold;
  }
  /// <summary>
  /// Runs a final evaluation on a freshly built eval graph using the test data set,
  /// prints the resulting accuracy and returns it together with the predictions.
  /// </summary>
  /// <param name="train_session">Training session, used to compute or load the test bottlenecks.</param>
  /// <param name="module_spec">Not used by this implementation.</param>
  /// <param name="class_count">Number of output classes of the retrained layer.</param>
  /// <param name="image_lists">Label -> category -> image files, as built by create_image_lists.</param>
  /// <param name="jpeg_data_tensor">JPEG input tensor of the decoding sub-graph.</param>
  /// <param name="decoded_image_tensor">Output of the decoding/resizing sub-graph.</param>
  /// <param name="resized_image_tensor">Input tensor of the recognition module.</param>
  /// <param name="bottleneck_tensor">Bottleneck output of the recognition module.</param>
  /// <returns>The test accuracy and the predicted class index of each test sample.</returns>
  private (float, NDArray) run_final_eval(Session train_session, object module_spec, int class_count,
      Dictionary<string, Dictionary<string, string[]>> image_lists,
      Tensor jpeg_data_tensor, Tensor decoded_image_tensor,
      Tensor resized_image_tensor, Tensor bottleneck_tensor)
  {
      // A negative test_batch_size makes get_random_cached_bottlenecks return every
      // test image instead of a random sample.
      var test_set = get_random_cached_bottlenecks(train_session, image_lists,
          test_batch_size, "testing", bottleneck_dir, image_dir, jpeg_data_tensor,
          decoded_image_tensor, resized_image_tensor, bottleneck_tensor, tfhub_module);
      var test_bottlenecks = test_set.Item1;
      var test_ground_truth = test_set.Item2;

      // Separate eval graph whose variables are restored from the training checkpoint.
      var (eval_session, _, eval_input, eval_ground_truth, accuracy_op, prediction_op) =
          build_eval_session(class_count);

      var fetched = eval_session.run(new Tensor[] { accuracy_op, prediction_op },
          new FeedItem(eval_input, test_bottlenecks),
          new FeedItem(eval_ground_truth, test_ground_truth));
      float accuracy = (float)fetched[0];

      print($"final test accuracy: {(accuracy * 100).ToString("G4")}% (N={len(test_bottlenecks)})");
      return (accuracy, fetched[1]);
  }
  /// <summary>
  /// Builds a fresh eval graph (module + retrained layer without loss/optimizer),
  /// restores the trained weights from <c>CHECKPOINT_NAME</c> into a new session on
  /// that graph, and adds the evaluation ops.
  /// </summary>
  /// <param name="class_count">Number of output classes of the retrained layer.</param>
  /// <returns>
  /// (eval session, module input tensor, bottleneck input, ground-truth input,
  /// accuracy op, prediction op). Every tensor belongs to the eval graph.
  /// </returns>
  private (Session, Tensor, Tensor, Tensor, Tensor, Tensor)
      build_eval_session(int class_count)
  {
      // If quantized, we need to create the correct eval graph for exporting.
      // Names are prefixed with eval_ so they do not shadow the training-graph fields.
      var (eval_graph, eval_bottleneck_tensor, resized_input_tensor, eval_wants_quantization) = create_module_graph();
      var eval_sess = tf.Session(graph: eval_graph);
      Tensor eval_bottleneck_input = null;
      Tensor eval_ground_truth_input = null;
      Tensor evaluation_step = null;
      Tensor prediction = null;

      with(eval_graph.as_default(), graph =>
      {
          // Add the new layer for exporting.
          var (_, _, layer_input, layer_ground_truth, layer_output) =
              add_final_retrain_ops(class_count, final_tensor_name, eval_bottleneck_tensor,
                  eval_wants_quantization, is_training: false);
          // Capture the eval-graph tensors explicitly instead of reading the class fields
          // afterwards, so the result does not depend on add_final_retrain_ops' field writes.
          eval_bottleneck_input = layer_input;
          eval_ground_truth_input = layer_ground_truth;

          // Now we need to restore the values from the training graph to the eval graph.
          tf.train.Saver().restore(eval_sess, CHECKPOINT_NAME);

          (evaluation_step, prediction) = add_evaluation_step(layer_output, layer_ground_truth);
      });

      return (eval_sess, resized_input_tensor, eval_bottleneck_input, eval_ground_truth_input,
          evaluation_step, prediction);
  }
  /// <summary>
  /// Adds a new fully-connected layer followed by a softmax on top of the bottleneck and,
  /// for training graphs, the cross-entropy loss and a gradient-descent train op.
  ///
  /// Side effects: assigns the fields bottleneck_input, ground_truth_input and
  /// final_tensor, and also train_step when <paramref name="is_training"/> is true.
  ///
  /// The set up for the softmax and fully-connected layers is based on:
  /// https://www.tensorflow.org/tutorials/mnist/beginners/index.html
  /// </summary>
  /// <param name="class_count">Number of output classes of the new layer.</param>
  /// <param name="final_tensor_name">Name given to the softmax output op.</param>
  /// <param name="bottleneck_tensor">Module output; dimensions 0 and 1 of its static shape size the new layer.</param>
  /// <param name="quantize_layer">Quantization is not ported; passing true throws.</param>
  /// <param name="is_training">When false, only the forward ops are added.</param>
  /// <returns>
  /// (train op, mean cross-entropy, bottleneck input, ground-truth input, softmax output);
  /// the first two are null when <paramref name="is_training"/> is false.
  /// </returns>
  /// <exception cref="NotImplementedException"><paramref name="quantize_layer"/> is true.</exception>
  private (Operation, Tensor, Tensor, Tensor, Tensor) add_final_retrain_ops(int class_count, string final_tensor_name,
      Tensor bottleneck_tensor, bool quantize_layer, bool is_training)
  {
      var bottleneck_dims = bottleneck_tensor.TensorShape.Dimensions;
      var batch_size = bottleneck_dims[0];
      var bottleneck_tensor_size = bottleneck_dims[1];

      with(tf.name_scope("input"), input_scope =>
      {
          // Defaults to the module output, but cached bottleneck values can be fed in directly.
          bottleneck_input = tf.placeholder_with_default(
              bottleneck_tensor,
              shape: bottleneck_tensor.TensorShape.Dimensions,
              name: "BottleneckInputPlaceholder");
          ground_truth_input = tf.placeholder(tf.int64, new TensorShape(batch_size), name: "GroundTruthInput");
      });

      // Grouped under one scope so the new layer is easy to find in TensorBoard.
      Tensor logits = null;
      with(tf.name_scope("final_retrain_ops"), layer_scope =>
      {
          RefVariable weights = null;
          with(tf.name_scope("weights"), delegate
          {
              var initial_weights = tf.truncated_normal(new int[] { bottleneck_tensor_size, class_count }, stddev: 0.001f);
              weights = tf.Variable(initial_weights, name: "final_weights");
              variable_summaries(weights);
          });

          RefVariable biases = null;
          with(tf.name_scope("biases"), delegate
          {
              biases = tf.Variable(tf.zeros(class_count), name: "final_biases");
              variable_summaries(biases);
          });

          with(tf.name_scope("Wx_plus_b"), delegate
          {
              logits = tf.matmul(bottleneck_input, weights) + biases;
              tf.summary.histogram("pre_activations", logits);
          });
      });

      final_tensor = tf.nn.softmax(logits, name: final_tensor_name);

      // The Python original rewrites the graph in place here with
      // tf.contrib.quantize.create_training_graph / create_eval_graph; there is no port yet.
      if (quantize_layer)
          throw new NotImplementedException("quantize_layer");

      tf.summary.histogram("activations", final_tensor);

      // Eval graphs need neither a loss nor an optimizer.
      if (!is_training)
          return (null, null, bottleneck_input, ground_truth_input, final_tensor);

      Tensor cross_entropy_mean = null;
      with(tf.name_scope("cross_entropy"), delegate
      {
          cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
              labels: ground_truth_input, logits: logits);
      });
      tf.summary.scalar("cross_entropy", cross_entropy_mean);

      with(tf.name_scope("train"), delegate
      {
          train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy_mean);
      });

      return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input, final_tensor);
  }
  /// <summary>
  /// Attaches TensorBoard summaries for a variable under a "summaries" name scope:
  /// scalar mean, stddev, max and min, plus a histogram of its values.
  /// </summary>
  /// <param name="var">The variable to summarize.</param>
  private void variable_summaries(RefVariable var)
  {
      with(tf.name_scope("summaries"), summaries_scope =>
      {
          Tensor mean = tf.reduce_mean(var);
          tf.summary.scalar("mean", mean);

          // Population standard deviation: sqrt(mean((x - mean)^2)).
          Tensor deviation = null;
          with(tf.name_scope("stddev"), stddev_scope =>
          {
              Tensor squared_error = tf.square(var - mean);
              deviation = tf.sqrt(tf.reduce_mean(squared_error));
          });
          tf.summary.scalar("stddev", deviation);

          Tensor largest = tf.reduce_max(var);
          tf.summary.scalar("max", largest);
          Tensor smallest = tf.reduce_min(var);
          tf.summary.scalar("min", smallest);

          tf.summary.histogram("histogram", var);
      });
  }
  /// <summary>
  /// Creates a new graph and imports the pre-trained InceptionV3 feature-vector module
  /// from "graph/InceptionV3.meta" into it.
  /// </summary>
  /// <returns>
  /// (graph, bottleneck tensor, module input tensor, wants_quantization). Quantization
  /// is not supported here, so the last element is always false.
  /// </returns>
  private (Graph, Tensor, Tensor, bool) create_module_graph()
  {
      return with(tf.Graph().as_default(), graph =>
      {
          tf.train.import_meta_graph("graph/InceptionV3.meta");
          // The module's image input. The Python original instead creates a
          // float32 placeholder of shape [-1, 299, 299, 3] and applies hub.Module to it;
          // presumably "Placeholder" in the exported meta graph is that tensor.
          Tensor resized_input_tensor = graph.OperationByName("Placeholder");
          // Output of the feature-vector head, used as the bottleneck.
          Tensor bottleneck_tensor = graph.OperationByName("module_apply_default/hub_output/feature_vector/SpatialSqueeze");
          var wants_quantization = false;
          return (graph, bottleneck_tensor, resized_input_tensor, wants_quantization);
      });
  }
  /// <summary>
  /// Retrieves bottleneck values for a set of images, computing and caching any that are
  /// not on disk yet. Also refreshes the class_count field from <paramref name="image_lists"/>.
  /// </summary>
  /// <param name="sess">Session used to compute missing bottlenecks.</param>
  /// <param name="image_lists">Label -> category -> image files.</param>
  /// <param name="how_many">
  /// Number of samples to draw at random (with replacement). A negative value returns
  /// every image of <paramref name="category"/> in label order instead.
  /// </param>
  /// <param name="category">"training", "testing" or "validation".</param>
  /// <param name="bottleneck_dir">Root folder of the bottleneck cache.</param>
  /// <param name="image_dir">Root folder of the images.</param>
  /// <param name="jpeg_data_tensor">JPEG input tensor of the decoding sub-graph.</param>
  /// <param name="decoded_image_tensor">Output of the decoding/resizing sub-graph.</param>
  /// <param name="resized_input_tensor">Input tensor of the recognition module.</param>
  /// <param name="bottleneck_tensor">Bottleneck output of the recognition module.</param>
  /// <param name="module_name">Module identifier, used as a cache file-name suffix.</param>
  /// <returns>(bottleneck rows, ground-truth label indices, image paths).</returns>
  /// <exception cref="InvalidOperationException">Random samples were requested but there are no labels.</exception>
  private (NDArray, long[], string[]) get_random_cached_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
      int how_many, string category, string bottleneck_dir, string image_dir,
      Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor,
      Tensor bottleneck_tensor, string module_name)
  {
      var bottlenecks = new List<float[]>();
      var ground_truths = new List<long>();
      var filenames = new List<string>();

      // Label indices follow the dictionary's key order; materialize it once.
      var label_names = image_lists.Keys.ToArray();
      class_count = label_names.Length;

      if (how_many >= 0)
      {
          if (how_many > 0 && label_names.Length == 0)
              throw new InvalidOperationException("Cannot sample bottlenecks: no image classes are available.");

          // One generator for the whole batch. Creating `new Random()` per draw can reuse the
          // same time-based seed (e.g. on .NET Framework) and yield identical samples.
          var rng = new Random();
          for (int i = 0; i < how_many; i++)
          {
              int label_index = rng.Next(class_count);
              string label_name = label_names[label_index];
              // Any index is valid: get_image_path wraps it modulo the category size.
              int image_index = rng.Next(MAX_NUM_IMAGES_PER_CLASS);
              string image_name = get_image_path(image_lists, label_name, image_index,
                  image_dir, category);
              var bottleneck = get_or_create_bottleneck(
                  sess, image_lists, label_name, image_index, image_dir, category,
                  bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
                  resized_input_tensor, bottleneck_tensor, module_name);
              bottlenecks.Add(bottleneck);
              ground_truths.Add(label_index);
              filenames.Add(image_name);
          }
      }
      else
      {
          // Retrieve all bottlenecks.
          for (int label_index = 0; label_index < label_names.Length; label_index++)
          {
              string label_name = label_names[label_index];
              var category_images = image_lists[label_name][category];
              for (int image_index = 0; image_index < category_images.Length; image_index++)
              {
                  var bottleneck = get_or_create_bottleneck(
                      sess, image_lists, label_name, image_index, image_dir, category,
                      bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
                      resized_input_tensor, bottleneck_tensor, module_name);
                  bottlenecks.Add(bottleneck);
                  ground_truths.Add(label_index);
                  filenames.Add(category_images[image_index]);
              }
          }
      }

      return (bottlenecks.ToArray(), ground_truths.ToArray(), filenames.ToArray());
  }
  /// <summary>
  /// Inserts the operations needed to evaluate the accuracy of the results, under an
  /// "accuracy" name scope, and registers an "accuracy" scalar summary.
  /// </summary>
  /// <param name="result_tensor">Per-class scores; the arg-max is taken along axis 1.</param>
  /// <param name="ground_truth_tensor">Expected class indices.</param>
  /// <returns>(mean accuracy as float32, predicted class indices).</returns>
  private (Tensor, Tensor) add_evaluation_step(Tensor result_tensor, Tensor ground_truth_tensor)
  {
      Tensor prediction = null;
      Tensor correct_prediction = null;
      Tensor evaluation_step = null;

      with(tf.name_scope("accuracy"), outer_scope =>
      {
          with(tf.name_scope("correct_prediction"), inner_scope =>
          {
              prediction = tf.argmax(result_tensor, 1);
              correct_prediction = tf.equal(prediction, ground_truth_tensor);
          });
          with(tf.name_scope("accuracy"), inner_scope =>
          {
              // Fraction of matches: cast booleans to 0/1 and average them.
              var hits = tf.cast(correct_prediction, tf.float32);
              evaluation_step = tf.reduce_mean(hits);
          });
      });

      tf.summary.scalar("accuracy", evaluation_step);
      return (evaluation_step, prediction);
  }
  /// <summary>
  /// Ensures all the training, testing, and validation bottlenecks are cached on disk,
  /// printing a progress line every 300 images processed.
  /// </summary>
  /// <param name="sess">Session used to compute missing bottlenecks.</param>
  /// <param name="image_lists">Label -> category -> image files.</param>
  /// <param name="image_dir">Root folder of the images.</param>
  /// <param name="bottleneck_dir">Root folder of the bottleneck cache.</param>
  /// <param name="jpeg_data_tensor">JPEG input tensor of the decoding sub-graph.</param>
  /// <param name="decoded_image_tensor">Output of the decoding/resizing sub-graph.</param>
  /// <param name="resized_input_tensor">Input tensor of the recognition module.</param>
  /// <param name="bottleneck_tensor">Bottleneck output of the recognition module.</param>
  /// <param name="module_name">Module identifier, used as a cache file-name suffix.</param>
  private void cache_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
      string image_dir, string bottleneck_dir, Tensor jpeg_data_tensor, Tensor decoded_image_tensor,
      Tensor resized_input_tensor, Tensor bottleneck_tensor, string module_name)
  {
      string[] categories = { "training", "testing", "validation" };
      int how_many_bottlenecks = 0;
      foreach (var (label_name, label_lists) in image_lists)
      {
          foreach (var category in categories)
          {
              var category_list = label_lists[category];
              for (int index = 0; index < category_list.Length; index++)
              {
                  get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir, category,
                      bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
                      resized_input_tensor, bottleneck_tensor, module_name);
                  // Counts every image processed, including ones whose cache already existed.
                  how_many_bottlenecks++;
                  if (how_many_bottlenecks % 300 == 0)
                      print($"{how_many_bottlenecks} bottleneck files created.");
              }
          }
      }
  }
  /// <summary>
  /// Returns the bottleneck values for one image, creating the cache file first if it
  /// does not exist yet.
  /// </summary>
  /// <param name="sess">Session used when the bottleneck has to be computed.</param>
  /// <param name="image_lists">Label -> category -> image files.</param>
  /// <param name="label_name">Label of the image.</param>
  /// <param name="index">Image index; wrapped modulo the category size by get_image_path.</param>
  /// <param name="image_dir">Root folder of the images.</param>
  /// <param name="category">"training", "testing" or "validation".</param>
  /// <param name="bottleneck_dir">Root folder of the bottleneck cache.</param>
  /// <param name="jpeg_data_tensor">JPEG input tensor of the decoding sub-graph.</param>
  /// <param name="decoded_image_tensor">Output of the decoding/resizing sub-graph.</param>
  /// <param name="resized_input_tensor">Input tensor of the recognition module.</param>
  /// <param name="bottleneck_tensor">Bottleneck output of the recognition module.</param>
  /// <param name="module_name">Module identifier, used as a cache file-name suffix.</param>
  /// <returns>The values parsed from the comma-separated cache file.</returns>
  private float[] get_or_create_bottleneck(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
      string label_name, int index, string image_dir, string category, string bottleneck_dir,
      Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor,
      Tensor bottleneck_tensor, string module_name)
  {
      var sub_dir_path = Path.Join(bottleneck_dir, label_name);
      Directory.CreateDirectory(sub_dir_path);
      string bottleneck_path = get_bottleneck_path(image_lists, label_name, index,
          bottleneck_dir, category, module_name);
      if (!File.Exists(bottleneck_path))
          create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
              image_dir, category, sess, jpeg_data_tensor,
              decoded_image_tensor, resized_input_tensor,
              bottleneck_tensor);
      var bottleneck_string = File.ReadAllText(bottleneck_path);
      // ',' separates values in the cache, so parse with the invariant culture: under a
      // culture whose decimal separator is ',' (e.g. de-DE) the numbers would be misread.
      return Array.ConvertAll(bottleneck_string.Split(','),
          x => float.Parse(x, System.Globalization.CultureInfo.InvariantCulture));
  }
  /// <summary>
  /// Computes the bottleneck values for one image and writes them to
  /// <paramref name="bottleneck_path"/> as comma-separated, culture-invariant numbers.
  /// </summary>
  /// <param name="bottleneck_path">Destination cache file.</param>
  /// <param name="image_lists">Label -> category -> image files.</param>
  /// <param name="label_name">Label of the image.</param>
  /// <param name="index">Image index; wrapped modulo the category size by get_image_path.</param>
  /// <param name="image_dir">Root folder of the images.</param>
  /// <param name="category">"training", "testing" or "validation".</param>
  /// <param name="sess">Session used to run the decoding and recognition graphs.</param>
  /// <param name="jpeg_data_tensor">JPEG input tensor of the decoding sub-graph.</param>
  /// <param name="decoded_image_tensor">Output of the decoding/resizing sub-graph.</param>
  /// <param name="resized_input_tensor">Input tensor of the recognition module.</param>
  /// <param name="bottleneck_tensor">Bottleneck output of the recognition module.</param>
  /// <exception cref="FileNotFoundException">The source image does not exist.</exception>
  private void create_bottleneck_file(string bottleneck_path, Dictionary<string, Dictionary<string, string[]>> image_lists,
      string label_name, int index, string image_dir, string category, Session sess,
      Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor, Tensor bottleneck_tensor)
  {
      // Create a single bottleneck file.
      print("Creating bottleneck at " + bottleneck_path);
      var image_path = get_image_path(image_lists, label_name, index, image_dir, category);
      // Fail with a clear message; previously this printed and then crashed in ReadAllBytes.
      if (!File.Exists(image_path))
          throw new FileNotFoundException($"File does not exist {image_path}", image_path);
      var image_data = File.ReadAllBytes(image_path);
      var bottleneck_values = run_bottleneck_on_image(
          sess, image_data, jpeg_data_tensor, decoded_image_tensor,
          resized_input_tensor, bottleneck_tensor);
      var values = bottleneck_values.Data<float>();
      // ',' is the value separator, so numbers must not be formatted with a culture that uses
      // ',' as decimal separator. "G9" always round-trips a float exactly.
      var bottleneck_string = string.Join(",",
          values.Select(v => v.ToString("G9", System.Globalization.CultureInfo.InvariantCulture)));
      File.WriteAllText(bottleneck_path, bottleneck_string);
  }
  /// <summary>
  /// Runs inference on an image to extract the 'bottleneck' summary layer: the raw JPEG
  /// bytes go through the decoding sub-graph first, then through the recognition network.
  /// </summary>
  /// <param name="sess">Current active TensorFlow Session.</param>
  /// <param name="image_data">Raw JPEG bytes.</param>
  /// <param name="image_data_tensor">Input data layer in the graph.</param>
  /// <param name="decoded_image_tensor">Output of initial image resizing and preprocessing.</param>
  /// <param name="resized_input_tensor">The input node of the recognition graph.</param>
  /// <param name="bottleneck_tensor">Layer before the final softmax.</param>
  /// <returns>The bottleneck values with all size-1 dimensions squeezed out.</returns>
  private NDArray run_bottleneck_on_image(Session sess, byte[] image_data, Tensor image_data_tensor,
      Tensor decoded_image_tensor, Tensor resized_input_tensor, Tensor bottleneck_tensor)
  {
      // Stage 1: decode, resize and rescale the JPEG.
      var jpeg_feed = new FeedItem(image_data_tensor, new Tensor(image_data, TF_DataType.TF_STRING));
      var preprocessed = sess.run(decoded_image_tensor, jpeg_feed);

      // Stage 2: feed the preprocessed image through the recognition network.
      var image_feed = new FeedItem(resized_input_tensor, preprocessed);
      var raw_bottleneck = sess.run(bottleneck_tensor, image_feed);

      // Drop size-1 dimensions (presumably the batch axis) to get a flat vector.
      return np.squeeze(raw_bottleneck);
  }
  /// <summary>
  /// Returns the cache file path for an image's bottleneck: the image's path re-rooted at
  /// <paramref name="bottleneck_dir"/>, suffixed with a file-name-safe module name and ".txt".
  /// </summary>
  /// <param name="image_lists">Label -> category -> image files.</param>
  /// <param name="label_name">Label of the image.</param>
  /// <param name="index">Image index; wrapped modulo the category size by get_image_path.</param>
  /// <param name="bottleneck_dir">Root folder of the bottleneck cache.</param>
  /// <param name="category">"training", "testing" or "validation".</param>
  /// <param name="module_name">Module identifier (e.g. a URL) to embed in the file name.</param>
  private string get_bottleneck_path(Dictionary<string, Dictionary<string, string[]>> image_lists, string label_name, int index,
      string bottleneck_dir, string category, string module_name)
  {
      // URL scheme separators, URL/Unix slashes, drive colons and Windows backslashes
      // all become '~' so the module name can be part of a file name.
      var safe_module_name = module_name
          .Replace("://", "~")
          .Replace('/', '~')
          .Replace(':', '~')
          .Replace('\\', '~');
      var image_path = get_image_path(image_lists, label_name, index, bottleneck_dir, category);
      return $"{image_path}_{safe_module_name}.txt";
  }
  /// <summary>
  /// Returns the path of the image at <paramref name="index"/> for a label/category pair,
  /// rooted at <paramref name="image_dir"/> (<c>image_dir/label_name/file_name</c>).
  /// </summary>
  /// <param name="image_lists">Label -> category -> image files.</param>
  /// <param name="label_name">Label of the image.</param>
  /// <param name="index">Any integer; it is wrapped modulo the number of images in the category.</param>
  /// <param name="image_dir">Root folder to place the file name under.</param>
  /// <param name="category">"training", "testing" or "validation".</param>
  /// <exception cref="KeyNotFoundException">The label or category does not exist.</exception>
  /// <exception cref="InvalidOperationException">The category contains no images.</exception>
  private string get_image_path(Dictionary<string, Dictionary<string, string[]>> image_lists, string label_name,
      int index, string image_dir, string category)
  {
      // Fail with descriptive errors instead of printing and then crashing on the indexer
      // (missing keys) or on `index % 0` (empty category).
      if (!image_lists.TryGetValue(label_name, out var label_lists))
          throw new KeyNotFoundException($"Label does not exist {label_name}");
      if (!label_lists.TryGetValue(category, out var category_list))
          throw new KeyNotFoundException($"Category does not exist {category}");
      if (category_list.Length == 0)
          throw new InvalidOperationException($"Label {label_name} has no images in the category {category}.");

      // Normalize so negative indices also map into range, matching Python's modulo.
      int count = category_list.Length;
      int mod_index = ((index % count) + count) % count;
      var base_name = Path.GetFileName(category_list[mod_index]);
      return Path.Join(image_dir, label_name, base_name);
  }
  /// <summary>
  /// Saves a graph to file: builds a fresh eval graph with the trained weights restored,
  /// converts its variables to constants (keeping what <c>final_tensor_name</c> needs)
  /// and writes the serialized GraphDef. Quantized export is not supported.
  /// </summary>
  /// <param name="graph_file_name">Destination .pb file.</param>
  /// <param name="class_count">Number of output classes of the retrained layer.</param>
  private void save_graph_to_file(string graph_file_name, int class_count)
  {
      var eval_session = build_eval_session(class_count).Item1;
      var output_names = new[] { final_tensor_name };
      var frozen_graph_def = tf.graph_util.convert_variables_to_constants(
          eval_session, eval_session.graph.as_graph_def(), output_names);
      byte[] serialized = frozen_graph_def.ToByteArray();
      File.WriteAllBytes(graph_file_name, serialized);
  }
  /// <summary>
  /// Downloads and extracts the flower photos, the InceptionV3 meta graph and the module
  /// checkpoint; creates the working directories; builds <c>image_lists</c> and sets
  /// <c>class_count</c>, warning when fewer than two classes were found.
  /// </summary>
  public void PrepareData()
  {
      // Images used to teach the network about the new classes.
      string fileName = "flower_photos.tgz";
      var photos_url = $"http://download.tensorflow.org/example_images/{fileName}";
      Web.Download(photos_url, data_dir, fileName);
      Compress.ExtractTGZ(Path.Join(data_dir, fileName), data_dir);

      // Meta graph of the pre-trained module (read by create_module_graph).
      var meta_url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/InceptionV3.meta";
      Web.Download(meta_url, "graph", "InceptionV3.meta");

      // variables.data checkpoint of the module.
      // NOTE(review): the archive is extracted to "tfhub_modules" relative to the working
      // directory, not under data_dir — confirm this is intended.
      var modules_url = "https://github.com/SciSharp/TensorFlow.NET/raw/master/data/tfhub_modules.zip";
      Web.Download(modules_url, data_dir, "tfhub_modules.zip");
      Compress.UnZip(Path.Join(data_dir, "tfhub_modules.zip"), "tfhub_modules");

      // Working directories used during training.
      Directory.CreateDirectory(summaries_dir);
      Directory.CreateDirectory(bottleneck_dir);

      // One label per image folder.
      image_lists = create_image_lists();
      class_count = len(image_lists);
      switch (class_count)
      {
          case 0:
              print($"No valid folders of images found at {image_dir}");
              break;
          case 1:
              print("Only one valid folder of images found at " + image_dir + " - multiple classes are needed for classification.");
              break;
      }
  }
  /// <summary>
  /// Adds the ops that turn raw JPEG bytes into a 1 x 299 x 299 x 3 float image in [0, 1],
  /// the input size expected by the recognition module.
  /// </summary>
  /// <returns>(JPEG string placeholder "DecodeJPGInput", resized float image tensor).</returns>
  private (Tensor, Tensor) add_jpeg_decoding()
  {
      const int input_height = 299;
      const int input_width = 299;
      const int input_depth = 3;

      var jpeg_placeholder = tf.placeholder(tf.@string, name: "DecodeJPGInput");
      var rgb_image = tf.image.decode_jpeg(jpeg_placeholder, channels: input_depth);
      // uint8 [0, 255] -> float32 [0, 1].
      var float_image = tf.image.convert_image_dtype(rgb_image, tf.float32);
      // Prepend a batch dimension; bilinear resize operates on 4-D tensors.
      var batched_image = tf.expand_dims(float_image, 0);
      var target_size = tf.cast(tf.stack(new int[] { input_height, input_width }), dtype: tf.int32);
      return (jpeg_placeholder, tf.image.resize_bilinear(batched_image, target_size));
  }
  /// <summary>
  /// Builds the lists of training, testing and validation images from the file system:
  /// every directory returned by tf.gfile.Walk(image_dir) becomes a lower-cased label, and
  /// its files are split in sorted order into testing_percentage, validation_percentage
  /// and the remainder for training. Folders without files are skipped.
  /// </summary>
  /// <returns>Label -> ("testing" | "validation" | "training") -> file paths.</returns>
  private Dictionary<string, Dictionary<string, string[]>> create_image_lists()
  {
      var sub_dirs = tf.gfile.Walk(image_dir)
          .Select(x => x.Item1)
          .OrderBy(x => x)
          .ToArray();
      var result = new Dictionary<string, Dictionary<string, string[]>>();
      foreach (var sub_dir in sub_dirs)
      {
          var dir_name = sub_dir.Split(Path.DirectorySeparatorChar).Last();
          print($"Looking for images in '{dir_name}'");
          // Directory.GetFiles returns files in no guaranteed order; sort so the
          // testing/validation/training split is reproducible across machines and runs.
          var file_list = Directory.GetFiles(sub_dir)
              .OrderBy(f => f, StringComparer.Ordinal)
              .ToArray();
          // An empty folder would become a label with no images and later break
          // get_image_path (modulo by zero), so skip it like the Python original does.
          if (file_list.Length == 0)
          {
              print($"WARNING: No files found in '{dir_name}', skipping it.");
              continue;
          }
          if (file_list.Length < 20)
              print($"WARNING: Folder has less than 20 images, which may cause issues.");
          // Invariant lower-casing so labels do not depend on the current culture.
          var label_name = dir_name.ToLowerInvariant();
          int testing_count = (int)Math.Floor(file_list.Length * testing_percentage);
          int validation_count = (int)Math.Floor(file_list.Length * validation_percentage);
          result[label_name] = new Dictionary<string, string[]>
          {
              ["testing"] = file_list.Take(testing_count).ToArray(),
              ["validation"] = file_list.Skip(testing_count).Take(validation_count).ToArray(),
              ["training"] = file_list.Skip(testing_count + validation_count).ToArray(),
          };
      }
      return result;
  }
  /// <summary>
  /// Loads the pre-trained module graph and adds the trainable top layer to it. Stores the
  /// module tensors and the new layer's ops in the corresponding fields for Train.
  /// </summary>
  /// <returns>The graph holding both the module and the new layer.</returns>
  public Graph ImportGraph()
  {
      // Set up the pre-trained graph.
      var module = create_module_graph();
      var graph = module.Item1;
      bottleneck_tensor = module.Item2;
      resized_image_tensor = module.Item3;
      wants_quantization = module.Item4;

      // Add the new layer that will be trained.
      with(graph.as_default(), g =>
      {
          var retrain = add_final_retrain_ops(class_count, final_tensor_name, bottleneck_tensor,
              wants_quantization, is_training: true);
          train_step = retrain.Item1;
          cross_entropy = retrain.Item2;
          bottleneck_input = retrain.Item3;
          ground_truth_input = retrain.Item4;
          final_tensor = retrain.Item5;
      });

      return graph;
  }
  /// <summary>Building the graph from scratch is not implemented; Run uses ImportGraph when IsImportingGraph is true.</summary>
  /// <exception cref="NotImplementedException">Always.</exception>
  public Graph BuildGraph() => throw new NotImplementedException();
  /// <summary>
  /// Trains the new top layer on cached bottlenecks and reports train/validation accuracy
  /// every eval_step_interval steps (and on the last step). It then runs the final test
  /// evaluation and writes the frozen graph and the label list.
  /// </summary>
  /// <param name="sess">
  /// Session over the graph produced by ImportGraph; the fields bottleneck_input,
  /// ground_truth_input, final_tensor, train_step and cross_entropy must belong to it.
  /// </param>
  public void Train(Session sess)
  {
      var sw = new Stopwatch();
      // Initialize all weights: for the module to their pretrained values,
      // and for the newly added retraining layer to random initial values.
      var init = tf.global_variables_initializer();
      sess.run(init);
      var (jpeg_data_tensor, decoded_image_tensor) = add_jpeg_decoding();
      // Make sure the 'bottleneck' image summaries are calculated and cached on disk.
      cache_bottlenecks(sess, image_lists, image_dir,
          bottleneck_dir, jpeg_data_tensor,
          decoded_image_tensor, resized_image_tensor,
          bottleneck_tensor, tfhub_module);
      // Create the operations we need to evaluate the accuracy of our new layer.
      var (evaluation_step, _) = add_evaluation_step(final_tensor, ground_truth_input);
      // Merge all the summaries and write them out to the summaries_dir.
      var merged = tf.summary.merge_all();
      var train_writer = tf.summary.FileWriter(summaries_dir + "/train", sess.graph);
      var validation_writer = tf.summary.FileWriter(summaries_dir + "/validation", sess.graph);
      // Create a train saver that is used to restore values into an eval graph
      // when exporting models.
      var train_saver = tf.train.Saver();
      train_saver.save(sess, CHECKPOINT_NAME);
      sw.Restart();
      for (int i = 0; i < how_many_training_steps; i++)
      {
          var (train_bottlenecks, train_ground_truth, _) = get_random_cached_bottlenecks(
              sess, image_lists, train_batch_size, "training",
              bottleneck_dir, image_dir, jpeg_data_tensor,
              decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
              tfhub_module);
          // Feed the bottlenecks and ground truth into the graph, and run a training
          // step. Capture training summaries for TensorBoard with the `merged` op.
          var results = sess.run(
              new ITensorOrOperation[] { merged, train_step },
              new FeedItem(bottleneck_input, train_bottlenecks),
              new FeedItem(ground_truth_input, train_ground_truth));
          var train_summary = results[0];
          train_writer.add_summary(train_summary, i);
          // Every so often, print out how well the graph is training.
          bool is_last_step = (i + 1 == how_many_training_steps);
          if ((i % eval_step_interval) == 0 || is_last_step)
          {
              results = sess.run(
                  new Tensor[] { evaluation_step, cross_entropy },
                  new FeedItem(bottleneck_input, train_bottlenecks),
                  new FeedItem(ground_truth_input, train_ground_truth));
              (float train_accuracy, float cross_entropy_value) = (results[0], results[1]);
              print($"{DateTime.Now}: Step {i + 1}: Train accuracy = {train_accuracy * 100}%, Cross entropy = {cross_entropy_value.ToString("G4")}");
              var (validation_bottlenecks, validation_ground_truth, _) = get_random_cached_bottlenecks(
                  sess, image_lists, validation_batch_size, "validation",
                  bottleneck_dir, image_dir, jpeg_data_tensor,
                  decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
                  tfhub_module);
              // Run a validation step and capture training summaries for TensorBoard
              // with the `merged` op.
              results = sess.run(new Tensor[] { merged, evaluation_step },
                  new FeedItem(bottleneck_input, validation_bottlenecks),
                  new FeedItem(ground_truth_input, validation_ground_truth));
              (string validation_summary, float validation_accuracy) = (results[0], results[1]);
              validation_writer.add_summary(validation_summary, i);
              print($"{DateTime.Now}: Step {i + 1}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)}) {sw.ElapsedMilliseconds}ms");
              sw.Restart();
          }
          // Intermediate graph exports (intermediate_store_frequency) are not performed.
          // They cannot simply call save_graph_to_file here: it goes through
          // build_eval_session -> add_final_retrain_ops, which overwrites the
          // bottleneck_input / ground_truth_input fields this loop keeps feeding.
      }
      // After training is complete, force one last save of the train checkpoint.
      train_saver.save(sess, CHECKPOINT_NAME);
      // We've completed all our training, so run a final test evaluation on
      // some new images we haven't used before.
      (test_accuracy, predictions) = run_final_eval(sess, null, class_count, image_lists,
          jpeg_data_tensor, decoded_image_tensor, resized_image_tensor,
          bottleneck_tensor);
      // Write out the trained graph and labels with the weights stored as constants.
      print($"Save final result to : {output_graph}");
      save_graph_to_file(output_graph, class_count);
      File.WriteAllText(output_labels, string.Join("\n", image_lists.Keys));
  }
  /// <summary>Prediction with the retrained model is not implemented.</summary>
  /// <param name="sess">Unused.</param>
  /// <exception cref="NotImplementedException">Always.</exception>
  public void Predict(Session sess) => throw new NotImplementedException();
  /// <summary>A separate test pass is not implemented; Train already runs the final test evaluation.</summary>
  /// <param name="sess">Unused.</param>
  /// <exception cref="NotImplementedException">Always.</exception>
  public void Test(Session sess) => throw new NotImplementedException();
  628. }
  629. }