You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

Main.cs 4.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Text;
  5. using Tensorflow;
  6. using static Tensorflow.Binding;
  7. namespace TensorFlowNET.Examples.ImageProcessing.YOLO
  8. {
  /// <summary>
  /// Implementation of YOLO v3 object detector in Tensorflow
  /// https://github.com/YunYang1994/tensorflow-yolov3
  /// </summary>
  /// <remarks>
  /// <see cref="Run"/> loads the configuration and datasets, builds the graph and opens a session.
  /// <see cref="Train"/> is still an empty stub, and <see cref="Test"/>, <see cref="Predict"/> and
  /// <see cref="ImportGraph"/> throw <see cref="NotImplementedException"/>.
  /// </remarks>
  public class Main : IExample
  {
      /// <summary>Defaults to <c>false</c>. Presumably read by the example runner; confirm against IExample.</summary>
      public bool Enabled { get; set; } = false;

      /// <summary>
      /// When <c>true</c>, <see cref="Run"/> calls <see cref="ImportGraph"/> (currently not implemented)
      /// instead of <see cref="BuildGraph"/>. Defaults to <c>false</c>.
      /// </summary>
      public bool IsImportingGraph { get; set; } = false;

      /// <summary>Example name; also used as the root directory name in <see cref="PrepareData"/>.</summary>
      public string Name => "YOLOv3";

      #region args
      // Class index -> class name, one entry per line of the file at cfg.YOLO.CLASSES.
      Dictionary<int, string> classes;
      // Number of entries in `classes`.
      int num_classes;

      // Hyper-parameters copied from cfg.TRAIN / cfg.YOLO in PrepareData.
      float learn_rate_init;
      float learn_rate_end;
      int first_stage_epochs;
      int second_stage_epochs;
      int warmup_periods;
      // Start timestamp formatted as "year-month-day-hour-minute-second" (not zero-padded).
      string time;
      float moving_ave_decay;
      // Hard-coded to 150 in PrepareData.
      int max_bbox_per_scale;
      // Set to trainset.Length in PrepareData.
      int steps_per_period;

      Dataset trainset, testset;
      Config cfg;

      // Graph placeholders created in BuildGraph (no static shape is specified there).
      Tensor input_data;
      Tensor label_sbbox;
      Tensor label_mbbox;
      Tensor label_lbbox;
      Tensor true_sbboxes;
      Tensor true_mbboxes;
      Tensor true_lbboxes;
      Tensor trainable;
      #endregion

      /// <summary>
      /// Entry point: prepares data, obtains a graph (imported or freshly built depending on
      /// <see cref="IsImportingGraph"/>), then runs <see cref="Train"/> inside a session on that graph.
      /// </summary>
      /// <returns>Always <c>true</c> when no exception is thrown.</returns>
      public bool Run()
      {
          PrepareData();

          var graph = IsImportingGraph ? ImportGraph() : BuildGraph();

          // The session is disposed as soon as training returns.
          using (var sess = tf.Session(graph))
          {
              Train(sess);
          }

          return true;
      }

      /// <summary>Training step. Currently an empty stub that does nothing.</summary>
      /// <param name="sess">Session created by <see cref="Run"/>; unused for now.</param>
      public void Train(Session sess)
      {
      }

      /// <summary>Not implemented.</summary>
      /// <param name="sess">Unused.</param>
      /// <exception cref="NotImplementedException">Always thrown.</exception>
      public void Test(Session sess)
      {
          throw new NotImplementedException();
      }

      /// <summary>
      /// Creates a new graph, makes it the default, and defines the input placeholders under the
      /// "define_input" name scope. The "define_loss" scope is opened but left empty.
      /// </summary>
      /// <returns>The newly created graph.</returns>
      public Graph BuildGraph()
      {
          var graph = new Graph().as_default();

          tf_with(tf.name_scope("define_input"), scope =>
          {
              input_data = tf.placeholder(dtype: tf.float32, name: "input_data");
              label_sbbox = tf.placeholder(dtype: tf.float32, name: "label_sbbox");
              label_mbbox = tf.placeholder(dtype: tf.float32, name: "label_mbbox");
              label_lbbox = tf.placeholder(dtype: tf.float32, name: "label_lbbox");
              true_sbboxes = tf.placeholder(dtype: tf.float32, name: "sbboxes");
              true_mbboxes = tf.placeholder(dtype: tf.float32, name: "mbboxes");
              true_lbboxes = tf.placeholder(dtype: tf.float32, name: "lbboxes");
              // Note: the field is called `trainable` but the graph node is named "training".
              trainable = tf.placeholder(dtype: tf.@bool, name: "training");
          });

          tf_with(tf.name_scope("define_loss"), scope =>
          {
              // TODO: the model is not wired in yet.
              //model = new YOLOv3(input_data, trainable);
          });

          return graph;
      }

      /// <summary>Not implemented.</summary>
      /// <exception cref="NotImplementedException">Always thrown.</exception>
      public Graph ImportGraph()
      {
          throw new NotImplementedException();
      }

      /// <summary>Not implemented.</summary>
      /// <param name="sess">Unused.</param>
      /// <exception cref="NotImplementedException">Always thrown.</exception>
      public void Predict(Session sess)
      {
          throw new NotImplementedException();
      }

      /// <summary>
      /// Loads the configuration, ensures the "<see cref="Name"/>/data" directory exists, reads the
      /// class list, copies the training hyper-parameters into fields and creates the train/test datasets.
      /// </summary>
      public void PrepareData()
      {
          cfg = new Config(Name);

          // Relative path, so it is resolved against the current working directory.
          string dataDir = Path.Combine(Name, "data");
          Directory.CreateDirectory(dataDir);

          // One class name per line; the key is the zero-based line index.
          classes = new Dictionary<int, string>();
          foreach (var line in File.ReadAllLines(cfg.YOLO.CLASSES))
          {
              classes[classes.Count] = line;
          }
          num_classes = classes.Count;

          learn_rate_init = cfg.TRAIN.LEARN_RATE_INIT;
          learn_rate_end = cfg.TRAIN.LEARN_RATE_END;
          // "FISRT" is the spelling of the member on the config type; keep it as declared there.
          first_stage_epochs = cfg.TRAIN.FISRT_STAGE_EPOCHS;
          second_stage_epochs = cfg.TRAIN.SECOND_STAGE_EPOCHS;
          warmup_periods = cfg.TRAIN.WARMUP_EPOCHS;

          // The last component must be the seconds; it previously repeated the minute,
          // so two runs started within the same minute produced the same timestamp.
          DateTime now = DateTime.Now;
          time = $"{now.Year}-{now.Month}-{now.Day}-{now.Hour}-{now.Minute}-{now.Second}";

          moving_ave_decay = cfg.YOLO.MOVING_AVE_DECAY;
          max_bbox_per_scale = 150;

          trainset = new Dataset("train", cfg);
          testset = new Dataset("test", cfg);
          steps_per_period = trainset.Length;
      }
  }
  109. }