You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Main.cs 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Linq;
  5. using System.Text;
  6. using Tensorflow;
  7. using static Tensorflow.Binding;
  namespace TensorFlowNET.Examples.ImageProcessing.YOLO
  {
      /// <summary>
      /// Implementation of YOLO v3 object detector in Tensorflow
      /// https://github.com/YunYang1994/tensorflow-yolov3
      /// </summary>
      /// <remarks>
      /// Work in progress. <c>BuildGraph</c> currently defines the input placeholders, the loss,
      /// the learning-rate schedule and an exponential moving average of the trainable variables.
      /// <c>Train</c> and the first-stage training scope are still empty.
      /// <c>Test</c>, <c>Predict</c> and <c>ImportGraph</c> are not implemented.
      /// </remarks>
      public class Main : IExample
      {
          public bool Enabled { get; set; } = true;
          public bool IsImportingGraph { get; set; } = false;
          public string Name => "YOLOv3";

          #region args
          // Class-index -> class-name map, read from the file named by cfg.YOLO.CLASSES.
          Dictionary<int, string> classes;
          // Number of entries in `classes` (not yet read elsewhere in this file).
          int num_classes;
          // Learning rate at the end of warmup and the floor of the cosine decay.
          float learn_rate_init;
          float learn_rate_end;
          // Epoch counts for the two training stages; their sum sizes the LR schedule.
          int first_stage_epochs;
          int second_stage_epochs;
          // Warmup length in epochs (cfg.TRAIN.WARMUP_EPOCHS).
          int warmup_periods;
          // Local-time timestamp "Y-M-D-h-m-s" (no zero padding) captured in PrepareData; not yet read here.
          string time;
          // Decay passed to tf.train.ExponentialMovingAverage.
          float moving_ave_decay;
          // Fixed to 150 in PrepareData; not yet read in this file.
          int max_bbox_per_scale;
          // trainset.Length; presumably the number of steps (batches) per epoch -- TODO confirm against Dataset.
          int steps_per_period;
          Dataset trainset, testset;
          Config cfg;
          // Graph placeholders created in the "define_input" name scope.
          Tensor input_data;
          Tensor label_sbbox;
          Tensor label_mbbox;
          Tensor label_lbbox;
          Tensor true_sbboxes;
          Tensor true_mbboxes;
          Tensor true_lbboxes;
          Tensor trainable;
          // Not assigned anywhere in this file; Run() uses its own local session.
          Session sess;
          YOLOv3 model;
          // Snapshot of tf.global_variables() taken right after the model is built.
          VariableV1[] net_var;
          Tensor giou_loss, conf_loss, prob_loss;
          // float64 step counter, starts at 1 and is not trainable.
          RefVariable global_step;
          Tensor learn_rate;
          // Sum of giou_loss, conf_loss and prob_loss.
          Tensor loss;
          #endregion

          /// <summary>
          /// Entry point: loads configuration and datasets, builds (or imports) the graph and
          /// runs <c>Train</c> inside a session that allows soft device placement.
          /// </summary>
          /// <returns>Always <c>true</c>.</returns>
          public bool Run()
          {
              PrepareData();

              var graph = IsImportingGraph ? ImportGraph() : BuildGraph();

              var options = new SessionOptions();
              // Allow TensorFlow to place an op on another device when the requested one cannot run it.
              options.SetConfig(new ConfigProto { AllowSoftPlacement = true });

              // Named `session` (not `sess`) so it does not shadow the class field of the same name.
              using (var session = tf.Session(graph, opts: options))
              {
                  Train(session);
              }

              return true;
          }

          /// <summary>
          /// Training loop. Not implemented yet: currently a no-op so that <c>Run</c> completes.
          /// </summary>
          /// <param name="sess">Session bound to the graph built by <c>BuildGraph</c>.</param>
          public void Train(Session sess)
          {
          }

          /// <summary>Evaluation is not implemented yet.</summary>
          /// <exception cref="NotImplementedException">Always thrown.</exception>
          public void Test(Session sess)
          {
              throw new NotImplementedException();
          }

          /// <summary>
          /// Builds the training graph in a new default graph:
          /// the input placeholders, the YOLOv3 model and its losses, the warmup + cosine
          /// learning-rate schedule, and an exponential moving average over the trainable variables.
          /// </summary>
          /// <returns>The graph that was made the default.</returns>
          public Graph BuildGraph()
          {
              var graph = new Graph().as_default();

              tf_with(tf.name_scope("define_input"), scope =>
              {
                  // No shapes are given, so every placeholder accepts any shape.
                  input_data = tf.placeholder(dtype: tf.float32, name: "input_data");
                  label_sbbox = tf.placeholder(dtype: tf.float32, name: "label_sbbox");
                  label_mbbox = tf.placeholder(dtype: tf.float32, name: "label_mbbox");
                  label_lbbox = tf.placeholder(dtype: tf.float32, name: "label_lbbox");
                  true_sbboxes = tf.placeholder(dtype: tf.float32, name: "sbboxes");
                  true_mbboxes = tf.placeholder(dtype: tf.float32, name: "mbboxes");
                  true_lbboxes = tf.placeholder(dtype: tf.float32, name: "lbboxes");
                  trainable = tf.placeholder(dtype: tf.@bool, name: "training");
              });

              tf_with(tf.name_scope("define_loss"), scope =>
              {
                  model = new YOLOv3(cfg, input_data, trainable);
                  net_var = tf.global_variables();
                  (giou_loss, conf_loss, prob_loss) = model.compute_loss(
                      label_sbbox, label_mbbox, label_lbbox,
                      true_sbboxes, true_mbboxes, true_lbboxes);
                  loss = giou_loss + conf_loss + prob_loss;
              });

              // Kept for the (not yet written) training stages; currently not consumed.
              Tensor global_step_update = null;
              tf_with(tf.name_scope("learn_rate"), scope =>
              {
                  global_step = tf.Variable(1.0, dtype: tf.float64, trainable: false, name: "global_step");
                  var warmup_steps = tf.constant(warmup_periods * steps_per_period,
                      dtype: tf.float64, name: "warmup_steps");
                  var train_steps = tf.constant((first_stage_epochs + second_stage_epochs) * steps_per_period,
                      dtype: tf.float64, name: "train_steps");

                  // Linear warmup from 0 to learn_rate_init, then cosine decay
                  // from learn_rate_init down to learn_rate_end at train_steps.
                  learn_rate = tf.cond(
                      pred: global_step < warmup_steps,
                      true_fn: delegate
                      {
                          return global_step / warmup_steps * learn_rate_init;
                      },
                      false_fn: delegate
                      {
                          return learn_rate_end + 0.5 * (learn_rate_init - learn_rate_end) *
                              (1 + tf.cos(
                                  (global_step - warmup_steps) / (train_steps - warmup_steps) * Math.PI));
                      }
                  );

                  // Increment with a double so it matches global_step's float64 dtype
                  // (the original passed a float32 literal, 1.0f).
                  global_step_update = tf.assign_add(global_step, 1.0);
              });

              // Kept for the (not yet written) training stages; currently not consumed.
              Operation moving_ave = null;
              tf_with(tf.name_scope("define_weight_decay"), scope =>
              {
                  // Despite the scope name, this is an exponential moving average of the trainable
                  // variables, not a weight-decay penalty.
                  var emv = tf.train.ExponentialMovingAverage(moving_ave_decay);
                  var vars = tf.trainable_variables().Select(x => (RefVariable)x).ToArray();
                  moving_ave = emv.apply(vars);
              });

              tf_with(tf.name_scope("define_first_stage_train"), scope =>
              {
                  // TODO: first-stage optimizer/train op not implemented yet.
              });

              return graph;
          }

          /// <summary>Importing a pre-built graph is not implemented yet.</summary>
          /// <exception cref="NotImplementedException">Always thrown.</exception>
          public Graph ImportGraph()
          {
              throw new NotImplementedException();
          }

          /// <summary>Inference is not implemented yet.</summary>
          /// <exception cref="NotImplementedException">Always thrown.</exception>
          public void Predict(Session sess)
          {
              throw new NotImplementedException();
          }

          /// <summary>
          /// Loads the configuration, creates the "YOLOv3/data" directory, reads the class names,
          /// copies the training hyper-parameters into fields, records a run timestamp and loads
          /// the train/test datasets.
          /// </summary>
          public void PrepareData()
          {
              cfg = new Config(Name);

              string dataDir = Path.Combine(Name, "data");
              Directory.CreateDirectory(dataDir);

              classes = Utils.read_class_names(cfg.YOLO.CLASSES);
              num_classes = classes.Count;

              learn_rate_init = cfg.TRAIN.LEARN_RATE_INIT;
              learn_rate_end = cfg.TRAIN.LEARN_RATE_END;
              first_stage_epochs = cfg.TRAIN.FISRT_STAGE_EPOCHS;
              second_stage_epochs = cfg.TRAIN.SECOND_STAGE_EPOCHS;
              warmup_periods = cfg.TRAIN.WARMUP_EPOCHS;

              // Year-Month-Day-Hour-Minute-Second; the last component was previously Minute by mistake.
              DateTime now = DateTime.Now;
              time = $"{now.Year}-{now.Month}-{now.Day}-{now.Hour}-{now.Minute}-{now.Second}";

              moving_ave_decay = cfg.YOLO.MOVING_AVE_DECAY;
              max_bbox_per_scale = 150;

              trainset = new Dataset("train", cfg);
              testset = new Dataset("test", cfg);
              steps_per_period = trainset.Length;
          }
      }
  }