You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Main.cs 8.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Linq;
  5. using System.Text;
  6. using Tensorflow;
  7. using static Tensorflow.Binding;
  namespace TensorFlowNET.Examples.ImageProcessing.YOLO
  {
      /// <summary>
      /// Implementation of YOLO v3 object detector in Tensorflow
      /// https://github.com/YunYang1994/tensorflow-yolov3
      /// </summary>
      /// <remarks>
      /// Builds a two-stage training graph:
      /// <list type="bullet">
      /// <item>Stage one trains only the variables whose first name segment is listed in <see cref="FirstStageScopes"/>.</item>
      /// <item>Stage two trains every trainable variable.</item>
      /// </list>
      /// <see cref="Train"/> is currently an empty stub. <see cref="Test"/>, <see cref="Predict"/> and
      /// <see cref="ImportGraph"/> throw <see cref="NotImplementedException"/>.
      /// </remarks>
      public class Main : IExample
      {
          public bool Enabled { get; set; } = true;

          /// <summary>When true, <see cref="Run"/> calls <see cref="ImportGraph"/> instead of <see cref="BuildGraph"/>.</summary>
          public bool IsImportingGraph { get; set; } = false;

          /// <summary>Example name; also passed to <see cref="Config"/> and used as the root of the data directory.</summary>
          public string Name => "YOLOv3";

          /// <summary>
          /// First segments of variable op names ("scope/...") selected for the first training stage.
          /// </summary>
          private static readonly string[] FirstStageScopes = { "conv_sbbox", "conv_mbbox", "conv_lbbox" };

          #region args
          Dictionary<int, string> classes;
          int num_classes;
          float learn_rate_init;
          float learn_rate_end;
          int first_stage_epochs;
          int second_stage_epochs;
          int warmup_periods;
          string time;
          float moving_ave_decay;
          int max_bbox_per_scale;
          int steps_per_period;
          Dataset trainset, testset;
          Config cfg;
          Tensor input_data;
          Tensor label_sbbox;
          Tensor label_mbbox;
          Tensor label_lbbox;
          Tensor true_sbboxes;
          Tensor true_mbboxes;
          Tensor true_lbboxes;
          Tensor trainable;
          Session sess;
          YOLOv3 model;
          VariableV1[] net_var;
          Tensor giou_loss, conf_loss, prob_loss;
          RefVariable global_step;
          Tensor learn_rate;
          Tensor loss;
          List<RefVariable> first_stage_trainable_var_list;
          Operation train_op_with_frozen_variables;
          Operation train_op_with_all_variables;
          #endregion

          /// <summary>
          /// Loads configuration and datasets, then builds the graph (or imports it when
          /// <see cref="IsImportingGraph"/> is set). Finally calls <see cref="Train"/> inside a
          /// session that allows soft device placement.
          /// </summary>
          /// <returns>Always <c>true</c>.</returns>
          public bool Run()
          {
              PrepareData();

              var graph = IsImportingGraph ? ImportGraph() : BuildGraph();

              var options = new SessionOptions();
              options.SetConfig(new ConfigProto { AllowSoftPlacement = true });

              // Named "session" so it does not hide the "sess" field.
              using (var session = tf.Session(graph, opts: options))
              {
                  Train(session);
              }

              return true;
          }

          /// <summary>Training loop. Currently an empty stub.</summary>
          /// <param name="sess">Session bound to the graph returned by <see cref="BuildGraph"/>.</param>
          public void Train(Session sess)
          {
          }

          /// <summary>Not implemented yet.</summary>
          /// <exception cref="NotImplementedException">Always.</exception>
          public void Test(Session sess)
          {
              throw new NotImplementedException();
          }

          /// <summary>
          /// Builds the training graph, in order:
          /// <list type="number">
          /// <item>Input placeholders.</item>
          /// <item>The <see cref="YOLOv3"/> model and its summed GIoU, confidence and probability losses.</item>
          /// <item>A learning rate with linear warmup followed by cosine decay.</item>
          /// <item>An exponential moving average over all trainable variables.</item>
          /// <item>Two train ops: <c>train_op_with_frozen_variables</c> (first-stage variables only)
          /// and <c>train_op_with_all_variables</c>.</item>
          /// </list>
          /// </summary>
          /// <returns>The new graph, made the default graph.</returns>
          public Graph BuildGraph()
          {
              var graph = new Graph().as_default();

              tf_with(tf.name_scope("define_input"), scope =>
              {
                  input_data = tf.placeholder(dtype: tf.float32, name: "input_data");
                  label_sbbox = tf.placeholder(dtype: tf.float32, name: "label_sbbox");
                  label_mbbox = tf.placeholder(dtype: tf.float32, name: "label_mbbox");
                  label_lbbox = tf.placeholder(dtype: tf.float32, name: "label_lbbox");
                  true_sbboxes = tf.placeholder(dtype: tf.float32, name: "sbboxes");
                  true_mbboxes = tf.placeholder(dtype: tf.float32, name: "mbboxes");
                  true_lbboxes = tf.placeholder(dtype: tf.float32, name: "lbboxes");
                  trainable = tf.placeholder(dtype: tf.@bool, name: "training");
              });

              tf_with(tf.name_scope("define_loss"), scope =>
              {
                  model = new YOLOv3(cfg, input_data, trainable);
                  net_var = tf.global_variables();
                  (giou_loss, conf_loss, prob_loss) = model.compute_loss(
                      label_sbbox, label_mbbox, label_lbbox,
                      true_sbboxes, true_mbboxes, true_lbboxes);
                  loss = giou_loss + conf_loss + prob_loss;
              });

              Tensor global_step_update = null;
              tf_with(tf.name_scope("learn_rate"), scope =>
              {
                  // float64 step counter starting at 1. Both train ops depend on global_step_update,
                  // so the counter is incremented on every train step.
                  global_step = tf.Variable(1.0, dtype: tf.float64, trainable: false, name: "global_step");
                  var warmup_steps = tf.constant(warmup_periods * steps_per_period,
                      dtype: tf.float64, name: "warmup_steps");
                  var train_steps = tf.constant((first_stage_epochs + second_stage_epochs) * steps_per_period,
                      dtype: tf.float64, name: "train_steps");

                  // While global_step < warmup_steps, the rate ramps linearly up to learn_rate_init.
                  // After that it follows a half cosine from learn_rate_init, reaching
                  // learn_rate_end when global_step == train_steps (cos(pi) == -1).
                  learn_rate = tf.cond(
                      pred: global_step < warmup_steps,
                      true_fn: delegate
                      {
                          return global_step / warmup_steps * learn_rate_init;
                      },
                      false_fn: delegate
                      {
                          return learn_rate_end + 0.5 * (learn_rate_init - learn_rate_end) *
                              (1 + tf.cos(
                                  (global_step - warmup_steps) / (train_steps - warmup_steps) * Math.PI));
                      }
                  );

                  // Increment with a double so the value matches global_step's float64 dtype.
                  global_step_update = tf.assign_add(global_step, 1.0);
              });

              Operation moving_ave = null;
              tf_with(tf.name_scope("define_weight_decay"), scope =>
              {
                  var emv = tf.train.ExponentialMovingAverage(moving_ave_decay);
                  var vars = tf.trainable_variables().Select(x => (RefVariable)x).ToArray();
                  moving_ave = emv.apply(vars);
              });

              tf_with(tf.name_scope("define_first_stage_train"), scope =>
              {
                  first_stage_trainable_var_list = new List<RefVariable>();
                  foreach (var variable in tf.trainable_variables())
                  {
                      // Select by the first segment of the op name, e.g. "conv_sbbox/...".
                      var top_scope = variable.op.name.Split('/')[0];
                      if (FirstStageScopes.Contains(top_scope))
                      {
                          // Same hard cast as define_weight_decay, which already casts every trainable variable.
                          first_stage_trainable_var_list.Add((RefVariable)variable);
                      }
                  }

                  var first_stage_optimizer = tf.train.AdamOptimizer(learn_rate)
                      .minimize(loss, var_list: first_stage_trainable_var_list);
                  train_op_with_frozen_variables = GroupTrainOp(first_stage_optimizer, global_step_update, moving_ave);
              });

              tf_with(tf.name_scope("define_second_stage_train"), scope =>
              {
                  var second_stage_trainable_var_list = tf.trainable_variables().Select(x => (RefVariable)x).ToList();
                  var second_stage_optimizer = tf.train.AdamOptimizer(learn_rate)
                      .minimize(loss, var_list: second_stage_trainable_var_list);
                  train_op_with_all_variables = GroupTrainOp(second_stage_optimizer, global_step_update, moving_ave);
              });

              return graph;
          }

          /// <summary>
          /// Creates a no-op with layered control dependencies: the ops in the <c>UPDATE_OPS</c>
          /// collection, then <paramref name="optimizer"/> and <paramref name="globalStepUpdate"/>,
          /// then <paramref name="movingAverage"/>. Running the returned op therefore forces all of
          /// them to run. Must be called inside the caller's name scope so the no-op is named within it.
          /// </summary>
          private static Operation GroupTrainOp(ITensorOrOperation optimizer, Tensor globalStepUpdate, Operation movingAverage)
          {
              Operation trainOp = null;
              tf_with(tf.control_dependencies(tf.get_collection<Operation>(tf.GraphKeys.UPDATE_OPS).ToArray()), delegate
              {
                  tf_with(tf.control_dependencies(new ITensorOrOperation[] { optimizer, globalStepUpdate }), delegate
                  {
                      tf_with(tf.control_dependencies(new[] { movingAverage }), delegate
                      {
                          trainOp = tf.no_op();
                      });
                  });
              });
              return trainOp;
          }

          /// <summary>Not implemented yet.</summary>
          /// <exception cref="NotImplementedException">Always.</exception>
          public Graph ImportGraph()
          {
              throw new NotImplementedException();
          }

          /// <summary>Not implemented yet.</summary>
          /// <exception cref="NotImplementedException">Always.</exception>
          public void Predict(Session sess)
          {
              throw new NotImplementedException();
          }

          /// <summary>
          /// Prepares everything <see cref="BuildGraph"/> needs:
          /// <list type="bullet">
          /// <item>Loads the <see cref="Config"/> for <see cref="Name"/> and ensures the
          /// <c>&lt;Name&gt;/data</c> directory exists.</item>
          /// <item>Reads the class names and copies training hyper-parameters from the config into fields.</item>
          /// <item>Constructs the train and test datasets.</item>
          /// </list>
          /// </summary>
          public void PrepareData()
          {
              cfg = new Config(Name);

              string dataDir = Path.Combine(Name, "data");
              Directory.CreateDirectory(dataDir);

              classes = Utils.read_class_names(cfg.YOLO.CLASSES);
              num_classes = classes.Count;
              learn_rate_init = cfg.TRAIN.LEARN_RATE_INIT;
              learn_rate_end = cfg.TRAIN.LEARN_RATE_END;
              first_stage_epochs = cfg.TRAIN.FISRT_STAGE_EPOCHS;
              second_stage_epochs = cfg.TRAIN.SECOND_STAGE_EPOCHS;
              warmup_periods = cfg.TRAIN.WARMUP_EPOCHS;

              // Zero-padded local-time run stamp, e.g. "2019-08-05-14-03-45".
              // The invariant culture keeps the Gregorian calendar regardless of the machine's locale.
              time = DateTime.Now.ToString("yyyy-MM-dd-HH-mm-ss", System.Globalization.CultureInfo.InvariantCulture);

              moving_ave_decay = cfg.YOLO.MOVING_AVE_DECAY;
              // NOTE(review): not read anywhere in this file; presumably a per-scale box cap used by later training code.
              max_bbox_per_scale = 150;

              trainset = new Dataset("train", cfg);
              testset = new Dataset("test", cfg);
              // One "period" is one pass over trainset; BuildGraph multiplies epoch counts by this to get steps.
              steps_per_period = trainset.Length;
          }
      }
  }