You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

YOLOv3.cs 3.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using Tensorflow;
  6. using static Tensorflow.Binding;
  namespace TensorFlowNET.Examples.ImageProcessing.YOLO
  {
      /// <summary>
      /// Builds a YOLOv3 graph: the <c>backbone.darknet53</c> feature extractor followed by
      /// three convolutional detection heads. Their raw outputs are stored in
      /// <c>conv_lbbox</c>, <c>conv_mbbox</c> and <c>conv_sbbox</c>.
      /// </summary>
      public class YOLOv3
      {
          Config cfg;
          // Forwarded to every convolutional layer as its `trainable` argument.
          Tensor trainable;
          Tensor input_data;
          // Class names read from cfg.YOLO.CLASSES; num_class is their count.
          Dictionary<int, string> classes;
          int num_class;
          NDArray strides;
          NDArray anchors;
          int anchor_per_scale;
          float iou_loss_thresh;
          // Forwarded to common.upsample as its `method` argument.
          string upsample_method;
          // Raw (not yet decoded) head outputs produced by __build_nework.
          Tensor conv_lbbox;
          Tensor conv_mbbox;
          Tensor conv_sbbox;

          /// <summary>
          /// Reads the YOLO settings from <paramref name="cfg_"/> and builds the network
          /// graph on top of <paramref name="input_data_"/>.
          /// </summary>
          /// <param name="cfg_">Configuration; only its <c>YOLO</c> section is read here.</param>
          /// <param name="input_data_">Input tensor fed to the backbone.</param>
          /// <param name="trainable_">
          /// Tensor passed to every convolutional layer as <c>trainable</c>
          /// (presumably a boolean training-mode flag — confirm against <c>common.convolutional</c>).
          /// </param>
          public YOLOv3(Config cfg_, Tensor input_data_, Tensor trainable_)
          {
              cfg = cfg_;
              input_data = input_data_;
              trainable = trainable_;
              classes = Utils.read_class_names(cfg.YOLO.CLASSES);
              num_class = len(classes);
              strides = np.array(cfg.YOLO.STRIDES);
              anchors = Utils.get_anchors(cfg.YOLO.ANCHORS);
              anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE;
              iou_loss_thresh = cfg.YOLO.IOU_LOSS_THRESH;
              upsample_method = cfg.YOLO.UPSAMPLE_METHOD;

              (conv_lbbox, conv_mbbox, conv_sbbox) = __build_nework(input_data);

              // Decoding of the raw head outputs is not implemented yet; each scope is a
              // placeholder for its scale. The anchors/strides index per scale is assumed to
              // follow sbbox=0, mbbox=1, lbbox=2 — confirm against the config.
              tf_with(tf.variable_scope("pred_sbbox"), scope =>
              {
                  // TODO: pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]);
              });
              tf_with(tf.variable_scope("pred_mbbox"), scope =>
              {
                  // TODO: pred_mbbox = decode(conv_mbbox, anchors[1], strides[1]);
              });
              tf_with(tf.variable_scope("pred_lbbox"), scope =>
              {
                  // TODO: pred_lbbox = decode(conv_lbbox, anchors[2], strides[2]);
              });
          }

          /// <summary>
          /// Builds the three detection heads on top of the backbone.
          /// </summary>
          /// <param name="input_data">Input tensor (shadows the field of the same name).</param>
          /// <returns>
          /// Raw head outputs <c>(conv_lbbox, conv_mbbox, conv_sbbox)</c>. Each is the output of a
          /// final 1x1 convolution with <c>3 * (num_class + 5)</c> filters, no activation and no BN.
          /// </returns>
          private (Tensor, Tensor, Tensor) __build_nework(Tensor input_data)
          {
              Tensor route_1, route_2;
              (route_1, route_2, input_data) = backbone.darknet53(input_data, trainable);

              // Large-object branch, on the last backbone output (1024 channels per conv52's filter shape).
              input_data = common.convolutional(input_data, new[] { 1, 1, 1024, 512 }, trainable, "conv52");
              input_data = common.convolutional(input_data, new[] { 3, 3, 512, 1024 }, trainable, "conv53");
              input_data = common.convolutional(input_data, new[] { 1, 1, 1024, 512 }, trainable, "conv54");
              input_data = common.convolutional(input_data, new[] { 3, 3, 512, 1024 }, trainable, "conv55");
              input_data = common.convolutional(input_data, new[] { 1, 1, 1024, 512 }, trainable, "conv56");

              var conv_lobj_branch = common.convolutional(input_data, new[] { 3, 3, 512, 1024 }, trainable, name: "conv_lobj_branch");
              Tensor lbbox = common.convolutional(conv_lobj_branch, new[] { 1, 1, 1024, 3 * (num_class + 5) },
                  trainable: trainable, name: "conv_lbbox", activate: false, bn: false);

              // Medium-object branch: reduce, upsample and merge with route_2.
              input_data = common.convolutional(input_data, new[] { 1, 1, 512, 256 }, trainable, "conv57");
              input_data = common.upsample(input_data, name: "upsample0", method: upsample_method);
              // Scope names follow the upstream YOLOv3 graph ("route_1" wraps the concat with route_2).
              tf_with(tf.variable_scope("route_1"), scope =>
              {
                  input_data = tf.concat(new[] { input_data, route_2 }, axis: -1);
              });

              // 768 = 256 (upsampled) + 512 (assumes route_2 carries 512 channels — confirm in backbone).
              input_data = common.convolutional(input_data, new[] { 1, 1, 768, 256 }, trainable, "conv58");
              input_data = common.convolutional(input_data, new[] { 3, 3, 256, 512 }, trainable, "conv59");
              input_data = common.convolutional(input_data, new[] { 1, 1, 512, 256 }, trainable, "conv60");
              input_data = common.convolutional(input_data, new[] { 3, 3, 256, 512 }, trainable, "conv61");
              input_data = common.convolutional(input_data, new[] { 1, 1, 512, 256 }, trainable, "conv62");

              var conv_mobj_branch = common.convolutional(input_data, new[] { 3, 3, 256, 512 }, trainable, name: "conv_mobj_branch");
              Tensor mbbox = common.convolutional(conv_mobj_branch, new[] { 1, 1, 512, 3 * (num_class + 5) },
                  trainable: trainable, name: "conv_mbbox", activate: false, bn: false);

              // Small-object branch: reduce, upsample and merge with route_1.
              input_data = common.convolutional(input_data, new[] { 1, 1, 256, 128 }, trainable, "conv63");
              input_data = common.upsample(input_data, name: "upsample1", method: upsample_method);
              tf_with(tf.variable_scope("route_2"), scope =>
              {
                  input_data = tf.concat(new[] { input_data, route_1 }, axis: -1);
              });

              // 384 = 128 (upsampled) + 256 (assumes route_1 carries 256 channels — confirm in backbone).
              input_data = common.convolutional(input_data, new[] { 1, 1, 384, 128 }, trainable, "conv64");
              input_data = common.convolutional(input_data, new[] { 3, 3, 128, 256 }, trainable, "conv65");
              input_data = common.convolutional(input_data, new[] { 1, 1, 256, 128 }, trainable, "conv66");
              input_data = common.convolutional(input_data, new[] { 3, 3, 128, 256 }, trainable, "conv67");
              input_data = common.convolutional(input_data, new[] { 1, 1, 256, 128 }, trainable, "conv68");

              var conv_sobj_branch = common.convolutional(input_data, new[] { 3, 3, 128, 256 }, trainable, name: "conv_sobj_branch");
              Tensor sbbox = common.convolutional(conv_sobj_branch, new[] { 1, 1, 256, 3 * (num_class + 5) },
                  trainable: trainable, name: "conv_sbbox", activate: false, bn: false);

              // Locals are named so they never shadow the conv_*bbox fields, which are still unset here.
              return (lbbox, mbbox, sbbox);
          }
      }
  }