You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Main.cs 2.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Text;
  5. using Tensorflow;
  6. using static Tensorflow.Binding;
  namespace TensorFlowNET.Examples.ImageProcessing.YOLO
  {
      /// <summary>
      /// YOLO v3 object-detector example for TensorFlow.NET, ported from
      /// https://github.com/YunYang1994/tensorflow-yolov3
      /// </summary>
      /// <remarks>
      /// Work in progress: <see cref="Train"/> currently does nothing, and
      /// <see cref="ImportGraph"/>, <see cref="Test"/> and <see cref="Predict"/>
      /// throw <see cref="NotImplementedException"/>.
      /// </remarks>
      public class Main : IExample
      {
          public bool Enabled { get; set; } = true;
          public bool IsImportingGraph { get; set; } = false;
          public string Name => "YOLOv3";

          // Zero-based line index of the classes file -> class name; filled by PrepareData().
          private Dictionary<int, string> _classNames;
          private Config _config;

          // Graph inputs, all created by BuildGraph() inside the "define_input" name scope.
          private Tensor _inputData;
          private Tensor _labelSbbox;
          private Tensor _labelMbbox;
          private Tensor _labelLbbox;
          private Tensor _trueSbboxes;
          private Tensor _trueMbboxes;
          private Tensor _trueLbboxes;
          // Boolean placeholder (graph name "training").
          private Tensor _trainable;

          /// <summary>
          /// Runs the example: loads the class list, obtains a graph (imported when
          /// <see cref="IsImportingGraph"/> is set, otherwise freshly built), then trains
          /// inside a session that is disposed once training returns.
          /// </summary>
          /// <returns>Always <c>true</c>.</returns>
          public bool Run()
          {
              PrepareData();

              Graph graph;
              if (IsImportingGraph)
              {
                  graph = ImportGraph();
              }
              else
              {
                  graph = BuildGraph();
              }

              using (var session = tf.Session(graph))
              {
                  Train(session);
              }

              return true;
          }

          /// <summary>Training entry point; the body is currently empty.</summary>
          /// <param name="sess">Session bound to the graph produced in <see cref="Run"/>.</param>
          public void Train(Session sess)
          {
              // TODO: training loop not implemented yet.
          }

          /// <summary>Evaluation entry point (not implemented).</summary>
          /// <exception cref="NotImplementedException">Always.</exception>
          public void Test(Session sess)
          {
              throw new NotImplementedException();
          }

          /// <summary>
          /// Creates a new graph, makes it the default graph, and defines the model inputs
          /// under the "define_input" name scope. The "define_loss" scope is entered but
          /// currently defines no ops.
          /// </summary>
          /// <returns>The graph that was made default.</returns>
          public Graph BuildGraph()
          {
              var graph = new Graph().as_default();

              tf_with(tf.name_scope("define_input"), ns =>
              {
                  _inputData = tf.placeholder(dtype: tf.float32, name: "input_data");
                  _labelSbbox = tf.placeholder(dtype: tf.float32, name: "label_sbbox");
                  _labelMbbox = tf.placeholder(dtype: tf.float32, name: "label_mbbox");
                  _labelLbbox = tf.placeholder(dtype: tf.float32, name: "label_lbbox");
                  _trueSbboxes = tf.placeholder(dtype: tf.float32, name: "sbboxes");
                  _trueMbboxes = tf.placeholder(dtype: tf.float32, name: "mbboxes");
                  _trueLbboxes = tf.placeholder(dtype: tf.float32, name: "lbboxes");
                  _trainable = tf.placeholder(dtype: tf.@bool, name: "training");
              });

              tf_with(tf.name_scope("define_loss"), ns =>
              {
                  // TODO: construct the YOLOv3 model from _inputData and _trainable here.
              });

              return graph;
          }

          /// <summary>Loads a pre-built graph (not implemented).</summary>
          /// <exception cref="NotImplementedException">Always.</exception>
          public Graph ImportGraph()
          {
              throw new NotImplementedException();
          }

          /// <summary>Inference entry point (not implemented).</summary>
          /// <exception cref="NotImplementedException">Always.</exception>
          public void Predict(Session sess)
          {
              throw new NotImplementedException();
          }

          /// <summary>
          /// Creates the configuration for this example, ensures a "<c>{Name}/data</c>"
          /// directory exists (relative to the current working directory), and reads the
          /// class-name file one name per line, keyed by zero-based line number.
          /// </summary>
          /// <remarks>
          /// NOTE(review): <c>Config.CLASSES</c> is presumably the path to the class-names
          /// file; confirm against the Config type.
          /// </remarks>
          public void PrepareData()
          {
              _config = new Config(Name);

              Directory.CreateDirectory(Path.Combine(Name, "data"));

              _classNames = new Dictionary<int, string>();
              string[] classLines = File.ReadAllLines(_config.CLASSES);
              for (int id = 0; id < classLines.Length; id++)
              {
                  _classNames[id] = classLines[id];
              }
          }
      }
  }