You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

Main.cs 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Text;
  5. using Tensorflow;
  6. using static Tensorflow.Binding;
  7. namespace TensorFlowNET.Examples.ImageProcessing.YOLO
  8. {
  9. /// <summary>
  10. /// Implementation of YOLO v3 object detector in Tensorflow
  11. /// https://github.com/YunYang1994/tensorflow-yolov3
  12. /// </summary>
  13. public class Main : IExample
  14. {
  15. public bool Enabled { get; set; } = true;
  16. public bool IsImportingGraph { get; set; } = false;
  17. public string Name => "YOLOv3";
  18. #region args
  19. Dictionary<int, string> classes;
  20. int num_classes;
  21. float learn_rate_init;
  22. float learn_rate_end;
  23. int first_stage_epochs;
  24. int second_stage_epochs;
  25. int warmup_periods;
  26. string time;
  27. float moving_ave_decay;
  28. int max_bbox_per_scale;
  29. int steps_per_period;
  30. Dataset trainset, testset;
  31. Config cfg;
  32. Tensor input_data;
  33. Tensor label_sbbox;
  34. Tensor label_mbbox;
  35. Tensor label_lbbox;
  36. Tensor true_sbboxes;
  37. Tensor true_mbboxes;
  38. Tensor true_lbboxes;
  39. Tensor trainable;
  40. Session sess;
  41. YOLOv3 model;
  42. #endregion
  43. public bool Run()
  44. {
  45. PrepareData();
  46. var graph = IsImportingGraph ? ImportGraph() : BuildGraph();
  47. var options = new SessionOptions();
  48. options.SetConfig(new ConfigProto { AllowSoftPlacement = true });
  49. using (var sess = tf.Session(graph, opts: options))
  50. {
  51. Train(sess);
  52. }
  53. return true;
  54. }
  55. public void Train(Session sess)
  56. {
  57. }
  58. public void Test(Session sess)
  59. {
  60. throw new NotImplementedException();
  61. }
  62. public Graph BuildGraph()
  63. {
  64. var graph = new Graph().as_default();
  65. tf_with(tf.name_scope("define_input"), scope =>
  66. {
  67. input_data = tf.placeholder(dtype: tf.float32, name: "input_data");
  68. label_sbbox = tf.placeholder(dtype: tf.float32, name: "label_sbbox");
  69. label_mbbox = tf.placeholder(dtype: tf.float32, name: "label_mbbox");
  70. label_lbbox = tf.placeholder(dtype: tf.float32, name: "label_lbbox");
  71. true_sbboxes = tf.placeholder(dtype: tf.float32, name: "sbboxes");
  72. true_mbboxes = tf.placeholder(dtype: tf.float32, name: "mbboxes");
  73. true_lbboxes = tf.placeholder(dtype: tf.float32, name: "lbboxes");
  74. trainable = tf.placeholder(dtype: tf.@bool, name: "training");
  75. });
  76. tf_with(tf.name_scope("define_loss"), scope =>
  77. {
  78. model = new YOLOv3(cfg, input_data, trainable);
  79. });
  80. tf_with(tf.name_scope("define_weight_decay"), scope =>
  81. {
  82. var moving_ave = tf.train.ExponentialMovingAverage(moving_ave_decay).apply((RefVariable[])tf.trainable_variables());
  83. });
  84. return graph;
  85. }
  86. public Graph ImportGraph()
  87. {
  88. throw new NotImplementedException();
  89. }
  90. public void Predict(Session sess)
  91. {
  92. throw new NotImplementedException();
  93. }
  94. public void PrepareData()
  95. {
  96. cfg = new Config(Name);
  97. string dataDir = Path.Combine(Name, "data");
  98. Directory.CreateDirectory(dataDir);
  99. classes = Utils.read_class_names(cfg.YOLO.CLASSES);
  100. num_classes = classes.Count;
  101. learn_rate_init = cfg.TRAIN.LEARN_RATE_INIT;
  102. learn_rate_end = cfg.TRAIN.LEARN_RATE_END;
  103. first_stage_epochs = cfg.TRAIN.FISRT_STAGE_EPOCHS;
  104. second_stage_epochs = cfg.TRAIN.SECOND_STAGE_EPOCHS;
  105. warmup_periods = cfg.TRAIN.WARMUP_EPOCHS;
  106. DateTime now = DateTime.Now;
  107. time = $"{now.Year}-{now.Month}-{now.Day}-{now.Hour}-{now.Minute}-{now.Minute}";
  108. moving_ave_decay = cfg.YOLO.MOVING_AVE_DECAY;
  109. max_bbox_per_scale = 150;
  110. trainset = new Dataset("train", cfg);
  111. testset = new Dataset("test", cfg);
  112. steps_per_period = trainset.Length;
  113. }
  114. }
  115. }