You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Main.cs 9.5 kB


  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Linq;
  5. using System.Text;
  6. using Tensorflow;
  7. using static Tensorflow.Binding;
  namespace TensorFlowNET.Examples.ImageProcessing.YOLO
  {
      /// <summary>
      /// Implementation of YOLO v3 object detector in Tensorflow
      /// https://github.com/YunYang1994/tensorflow-yolov3
      /// </summary>
      public class Main : IExample
      {
          public bool Enabled { get; set; } = true;
          public bool IsImportingGraph { get; set; } = false;
          public string Name => "YOLOv3";

          #region args
          Dictionary<int, string> classes;
          int num_classes;
          float learn_rate_init;
          float learn_rate_end;
          int first_stage_epochs;
          int second_stage_epochs;
          int warmup_periods;
          // Local-time timestamp of this run, e.g. "2019-09-05-13-07-42" (set in PrepareData).
          string time;
          float moving_ave_decay;
          int max_bbox_per_scale;
          // Set from trainset.Length in PrepareData; used to scale the learning-rate schedule.
          int steps_per_period;
          Dataset trainset, testset;
          Config cfg;
          Tensor input_data;
          Tensor label_sbbox;
          Tensor label_mbbox;
          Tensor label_lbbox;
          Tensor true_sbboxes;
          Tensor true_mbboxes;
          Tensor true_lbboxes;
          Tensor trainable;
          // NOTE(review): never assigned; Run() uses a local session instead.
          Session sess;
          YOLOv3 model;
          // Global variables captured right after the model is built; restored by `loader`.
          VariableV1[] net_var;
          Tensor giou_loss, conf_loss, prob_loss;
          RefVariable global_step;
          Tensor learn_rate;
          Tensor loss;
          List<RefVariable> first_stage_trainable_var_list;
          Operation train_op_with_frozen_variables;
          Operation train_op_with_all_variables;
          Saver loader;
          Saver saver;
          #endregion

          /// <summary>
          /// Entry point: loads configuration and datasets, builds (or imports) the graph,
          /// and runs <see cref="Train"/> in a session with soft device placement enabled.
          /// </summary>
          /// <returns>Always <c>true</c>.</returns>
          public bool Run()
          {
              PrepareData();

              var graph = IsImportingGraph ? ImportGraph() : BuildGraph();

              var options = new SessionOptions();
              options.SetConfig(new ConfigProto { AllowSoftPlacement = true });

              // Named `session` so it does not shadow the `sess` field.
              using (var session = tf.Session(graph, opts: options))
              {
                  Train(session);
              }

              return true;
          }

          /// <summary>
          /// Initializes all global variables, then restores the variables captured in
          /// <c>net_var</c> from <c>cfg.TRAIN.INITIAL_WEIGHT</c>.
          /// </summary>
          /// <param name="sess">Session bound to the graph produced by <see cref="BuildGraph"/>.</param>
          /// <remarks>
          /// NOTE(review): only initialization and weight restoration are ported so far;
          /// no training loop runs here yet.
          /// </remarks>
          public void Train(Session sess)
          {
              sess.run(tf.global_variables_initializer());
              print($"=> Restoring weights from: {cfg.TRAIN.INITIAL_WEIGHT} ... ");
              loader.restore(sess, cfg.TRAIN.INITIAL_WEIGHT);
          }

          /// <summary>Not implemented yet.</summary>
          public void Test(Session sess)
          {
              throw new NotImplementedException();
          }

          /// <summary>
          /// Builds the training graph and stores the created tensors/ops in fields:
          /// input placeholders, the YOLOv3 model and its three losses, a learning-rate
          /// schedule (linear warm-up, then cosine decay to <c>learn_rate_end</c>), an
          /// exponential moving average over trainable variables, two training ops
          /// (head-only and all-variables), a loader/saver, and scalar summaries.
          /// </summary>
          /// <returns>The newly created graph (made the default graph).</returns>
          public Graph BuildGraph()
          {
              var graph = new Graph().as_default();

              tf_with(tf.name_scope("define_input"), scope =>
              {
                  input_data = tf.placeholder(dtype: tf.float32, name: "input_data");
                  label_sbbox = tf.placeholder(dtype: tf.float32, name: "label_sbbox");
                  label_mbbox = tf.placeholder(dtype: tf.float32, name: "label_mbbox");
                  label_lbbox = tf.placeholder(dtype: tf.float32, name: "label_lbbox");
                  true_sbboxes = tf.placeholder(dtype: tf.float32, name: "sbboxes");
                  true_mbboxes = tf.placeholder(dtype: tf.float32, name: "mbboxes");
                  true_lbboxes = tf.placeholder(dtype: tf.float32, name: "lbboxes");
                  trainable = tf.placeholder(dtype: tf.@bool, name: "training");
              });

              tf_with(tf.name_scope("define_loss"), scope =>
              {
                  model = new YOLOv3(cfg, input_data, trainable);
                  // Snapshot taken before optimizer/EMA variables exist, so the loader
                  // only restores the network's own variables.
                  net_var = tf.global_variables();
                  (giou_loss, conf_loss, prob_loss) = model.compute_loss(
                      label_sbbox, label_mbbox, label_lbbox,
                      true_sbboxes, true_mbboxes, true_lbboxes);
                  loss = giou_loss + conf_loss + prob_loss;
              });

              Tensor global_step_update = null;
              tf_with(tf.name_scope("learn_rate"), scope =>
              {
                  global_step = tf.Variable(1.0, dtype: tf.float64, trainable: false, name: "global_step");

                  // Cast explicitly: these constants are float64, so don't rely on an
                  // implicit int -> float64 conversion inside tf.constant.
                  var warmup_steps = tf.constant((double)(warmup_periods * steps_per_period),
                      dtype: tf.float64, name: "warmup_steps");
                  var train_steps = tf.constant((double)((first_stage_epochs + second_stage_epochs) * steps_per_period),
                      dtype: tf.float64, name: "train_steps");

                  // Linear warm-up to learn_rate_init, then cosine decay towards learn_rate_end.
                  learn_rate = tf.cond(
                      pred: global_step < warmup_steps,
                      true_fn: delegate
                      {
                          return global_step / warmup_steps * learn_rate_init;
                      },
                      false_fn: delegate
                      {
                          return learn_rate_end + 0.5 * (learn_rate_init - learn_rate_end) *
                              (1 + tf.cos(
                                  (global_step - warmup_steps) / (train_steps - warmup_steps) * Math.PI));
                      }
                  );

                  // global_step is float64; increment with a double (not 1.0f) so dtypes match.
                  global_step_update = tf.assign_add(global_step, 1.0);
              });

              Operation moving_ave = null;
              tf_with(tf.name_scope("define_weight_decay"), scope =>
              {
                  var emv = tf.train.ExponentialMovingAverage(moving_ave_decay);
                  var vars = tf.trainable_variables().Select(x => (RefVariable)x).ToArray();
                  moving_ave = emv.apply(vars);
              });

              tf_with(tf.name_scope("define_first_stage_train"), scope =>
              {
                  // Stage one trains only variables whose top-level scope is one of these
                  // (presumably the three detection heads); everything else stays frozen.
                  var head_scopes = new[] { "conv_sbbox", "conv_mbbox", "conv_lbbox" };

                  first_stage_trainable_var_list = new List<RefVariable>();
                  foreach (var variable in tf.trainable_variables())
                  {
                      var top_scope = variable.op.name.Split('/')[0];
                      if (head_scopes.Contains(top_scope))
                          first_stage_trainable_var_list.Add((RefVariable)variable);
                  }

                  train_op_with_frozen_variables = BuildTrainOp(
                      first_stage_trainable_var_list, global_step_update, moving_ave);
              });

              tf_with(tf.name_scope("define_second_stage_train"), delegate
              {
                  // Stage two trains every trainable variable.
                  var second_stage_trainable_var_list = tf.trainable_variables()
                      .Select(x => (RefVariable)x)
                      .ToList();

                  train_op_with_all_variables = BuildTrainOp(
                      second_stage_trainable_var_list, global_step_update, moving_ave);
              });

              tf_with(tf.name_scope("loader_and_saver"), delegate
              {
                  loader = tf.train.Saver(net_var);
                  saver = tf.train.Saver(tf.global_variables(), max_to_keep: 10);
              });

              tf_with(tf.name_scope("summary"), delegate
              {
                  tf.summary.scalar("learn_rate", learn_rate);
                  tf.summary.scalar("giou_loss", giou_loss);
                  tf.summary.scalar("conf_loss", conf_loss);
                  tf.summary.scalar("prob_loss", prob_loss);
                  tf.summary.scalar("total_loss", loss);
              });

              return graph;
          }

          /// <summary>
          /// Creates one stage's training op in the current name scope: an Adam
          /// <c>minimize</c> over <paramref name="var_list"/>, plus a no-op that has control
          /// dependencies on the UPDATE_OPS collection, that minimize op,
          /// <paramref name="global_step_update"/>, and <paramref name="moving_ave"/>.
          /// </summary>
          /// <param name="var_list">Variables the optimizer is allowed to update.</param>
          /// <param name="global_step_update">Op that increments the global step.</param>
          /// <param name="moving_ave">Op that updates the exponential moving averages.</param>
          /// <returns>The no-op; running it runs all of its dependencies.</returns>
          private Operation BuildTrainOp(List<RefVariable> var_list, Tensor global_step_update, Operation moving_ave)
          {
              var adam = tf.train.AdamOptimizer(learn_rate);
              var optimizer = adam.minimize(loss, var_list: var_list);

              Operation train_op = null;
              tf_with(tf.control_dependencies(tf.get_collection<Operation>(tf.GraphKeys.UPDATE_OPS).ToArray()), delegate
              {
                  tf_with(tf.control_dependencies(new ITensorOrOperation[] { optimizer, global_step_update }), delegate
                  {
                      tf_with(tf.control_dependencies(new[] { moving_ave }), delegate
                      {
                          train_op = tf.no_op();
                      });
                  });
              });
              return train_op;
          }

          /// <summary>Not implemented yet.</summary>
          public Graph ImportGraph()
          {
              throw new NotImplementedException();
          }

          /// <summary>Not implemented yet.</summary>
          public void Predict(Session sess)
          {
              throw new NotImplementedException();
          }

          /// <summary>
          /// Loads the <see cref="Config"/> for this example, ensures the
          /// <c>{Name}/data</c> directory exists, reads class names, copies training
          /// hyper-parameters from the config, and constructs the train/test datasets.
          /// </summary>
          public void PrepareData()
          {
              cfg = new Config(Name);

              string dataDir = Path.Combine(Name, "data");
              Directory.CreateDirectory(dataDir);

              classes = Utils.read_class_names(cfg.YOLO.CLASSES);
              num_classes = classes.Count;

              learn_rate_init = cfg.TRAIN.LEARN_RATE_INIT;
              learn_rate_end = cfg.TRAIN.LEARN_RATE_END;
              first_stage_epochs = cfg.TRAIN.FISRT_STAGE_EPOCHS;
              second_stage_epochs = cfg.TRAIN.SECOND_STAGE_EPOCHS;
              warmup_periods = cfg.TRAIN.WARMUP_EPOCHS;

              // Year-month-day-hour-minute-second in local time (the previous format
              // emitted the minute twice and never the second).
              time = DateTime.Now.ToString("yyyy-MM-dd-HH-mm-ss", System.Globalization.CultureInfo.InvariantCulture);

              moving_ave_decay = cfg.YOLO.MOVING_AVE_DECAY;
              max_bbox_per_scale = 150;

              trainset = new Dataset("train", cfg);
              testset = new Dataset("test", cfg);
              steps_per_period = trainset.Length;
          }
      }
  }