You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Main.cs 7.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Linq;
  5. using System.Text;
  6. using Tensorflow;
  7. using static Tensorflow.Binding;
  8. namespace TensorFlowNET.Examples.ImageProcessing.YOLO
  9. {
  10. /// <summary>
  11. /// Implementation of YOLO v3 object detector in Tensorflow
  12. /// https://github.com/YunYang1994/tensorflow-yolov3
  13. /// </summary>
  14. public class Main : IExample
  15. {
  16. public bool Enabled { get; set; } = true;
  17. public bool IsImportingGraph { get; set; } = false;
  18. public string Name => "YOLOv3";
  19. #region args
  20. Dictionary<int, string> classes;
  21. int num_classes;
  22. float learn_rate_init;
  23. float learn_rate_end;
  24. int first_stage_epochs;
  25. int second_stage_epochs;
  26. int warmup_periods;
  27. string time;
  28. float moving_ave_decay;
  29. int max_bbox_per_scale;
  30. int steps_per_period;
  31. Dataset trainset, testset;
  32. Config cfg;
  33. Tensor input_data;
  34. Tensor label_sbbox;
  35. Tensor label_mbbox;
  36. Tensor label_lbbox;
  37. Tensor true_sbboxes;
  38. Tensor true_mbboxes;
  39. Tensor true_lbboxes;
  40. Tensor trainable;
  41. Session sess;
  42. YOLOv3 model;
  43. VariableV1[] net_var;
  44. Tensor giou_loss, conf_loss, prob_loss;
  45. RefVariable global_step;
  46. Tensor learn_rate;
  47. Tensor loss;
  48. List<RefVariable> first_stage_trainable_var_list;
  49. #endregion
  50. public bool Run()
  51. {
  52. PrepareData();
  53. var graph = IsImportingGraph ? ImportGraph() : BuildGraph();
  54. var options = new SessionOptions();
  55. options.SetConfig(new ConfigProto { AllowSoftPlacement = true });
  56. using (var sess = tf.Session(graph, opts: options))
  57. {
  58. Train(sess);
  59. }
  60. return true;
  61. }
  62. public void Train(Session sess)
  63. {
  64. }
  65. public void Test(Session sess)
  66. {
  67. throw new NotImplementedException();
  68. }
  69. public Graph BuildGraph()
  70. {
  71. var graph = new Graph().as_default();
  72. tf_with(tf.name_scope("define_input"), scope =>
  73. {
  74. input_data = tf.placeholder(dtype: tf.float32, name: "input_data");
  75. label_sbbox = tf.placeholder(dtype: tf.float32, name: "label_sbbox");
  76. label_mbbox = tf.placeholder(dtype: tf.float32, name: "label_mbbox");
  77. label_lbbox = tf.placeholder(dtype: tf.float32, name: "label_lbbox");
  78. true_sbboxes = tf.placeholder(dtype: tf.float32, name: "sbboxes");
  79. true_mbboxes = tf.placeholder(dtype: tf.float32, name: "mbboxes");
  80. true_lbboxes = tf.placeholder(dtype: tf.float32, name: "lbboxes");
  81. trainable = tf.placeholder(dtype: tf.@bool, name: "training");
  82. });
  83. tf_with(tf.name_scope("define_loss"), scope =>
  84. {
  85. model = new YOLOv3(cfg, input_data, trainable);
  86. net_var = tf.global_variables();
  87. (giou_loss, conf_loss, prob_loss) = model.compute_loss(
  88. label_sbbox, label_mbbox, label_lbbox,
  89. true_sbboxes, true_mbboxes, true_lbboxes);
  90. loss = giou_loss + conf_loss + prob_loss;
  91. });
  92. Tensor global_step_update = null;
  93. tf_with(tf.name_scope("learn_rate"), scope =>
  94. {
  95. global_step = tf.Variable(1.0, dtype: tf.float64, trainable: false, name: "global_step");
  96. var warmup_steps = tf.constant(warmup_periods * steps_per_period,
  97. dtype: tf.float64, name: "warmup_steps");
  98. var train_steps = tf.constant((first_stage_epochs + second_stage_epochs) * steps_per_period,
  99. dtype: tf.float64, name: "train_steps");
  100. learn_rate = tf.cond(
  101. pred: global_step < warmup_steps,
  102. true_fn: delegate
  103. {
  104. return global_step / warmup_steps * learn_rate_init;
  105. },
  106. false_fn: delegate
  107. {
  108. return learn_rate_end + 0.5 * (learn_rate_init - learn_rate_end) *
  109. (1 + tf.cos(
  110. (global_step - warmup_steps) / (train_steps - warmup_steps) * Math.PI));
  111. }
  112. );
  113. global_step_update = tf.assign_add(global_step, 1.0f);
  114. });
  115. Operation moving_ave = null;
  116. tf_with(tf.name_scope("define_weight_decay"), scope =>
  117. {
  118. var emv = tf.train.ExponentialMovingAverage(moving_ave_decay);
  119. var vars = tf.trainable_variables().Select(x => (RefVariable)x).ToArray();
  120. moving_ave = emv.apply(vars);
  121. });
  122. tf_with(tf.name_scope("define_first_stage_train"), scope =>
  123. {
  124. first_stage_trainable_var_list = new List<RefVariable>();
  125. foreach (var var in tf.trainable_variables())
  126. {
  127. var var_name = var.op.name;
  128. var var_name_mess = var_name.Split('/');
  129. if (new[] { "conv_sbbox", "conv_mbbox", "conv_lbbox" }.Contains(var_name_mess[0]))
  130. first_stage_trainable_var_list.Add(var as RefVariable);
  131. }
  132. var adam = tf.train.AdamOptimizer(learn_rate);
  133. var first_stage_optimizer = adam.minimize(loss, var_list: first_stage_trainable_var_list);
  134. });
  135. return graph;
  136. }
  137. public Graph ImportGraph()
  138. {
  139. throw new NotImplementedException();
  140. }
  141. public void Predict(Session sess)
  142. {
  143. throw new NotImplementedException();
  144. }
  145. public void PrepareData()
  146. {
  147. cfg = new Config(Name);
  148. string dataDir = Path.Combine(Name, "data");
  149. Directory.CreateDirectory(dataDir);
  150. classes = Utils.read_class_names(cfg.YOLO.CLASSES);
  151. num_classes = classes.Count;
  152. learn_rate_init = cfg.TRAIN.LEARN_RATE_INIT;
  153. learn_rate_end = cfg.TRAIN.LEARN_RATE_END;
  154. first_stage_epochs = cfg.TRAIN.FISRT_STAGE_EPOCHS;
  155. second_stage_epochs = cfg.TRAIN.SECOND_STAGE_EPOCHS;
  156. warmup_periods = cfg.TRAIN.WARMUP_EPOCHS;
  157. DateTime now = DateTime.Now;
  158. time = $"{now.Year}-{now.Month}-{now.Day}-{now.Hour}-{now.Minute}-{now.Minute}";
  159. moving_ave_decay = cfg.YOLO.MOVING_AVE_DECAY;
  160. max_bbox_per_scale = 150;
  161. trainset = new Dataset("train", cfg);
  162. testset = new Dataset("test", cfg);
  163. steps_per_period = trainset.Length;
  164. }
  165. }
  166. }