You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

Dataset.cs 1.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Text;
  6. using static Tensorflow.Binding;
  namespace TensorFlowNET.Examples.ImageProcessing.YOLO
  {
      /// <summary>
      /// Holds the YOLO dataset settings for one split (training or test) together with
      /// the annotation lines read from that split's annotation file.
      /// </summary>
      public class Dataset
      {
          readonly string annot_path;
          readonly int[] input_sizes;
          readonly int batch_size;
          readonly bool data_aug;
          readonly int[] train_input_sizes;
          readonly NDArray strides;
          readonly NDArray anchors;
          readonly Dictionary<int, string> classes;
          readonly int num_classes;
          readonly int anchor_per_scale;
          readonly int max_bbox_per_scale;
          readonly string[] annotations;
          readonly int num_samples;
          // Batch cursor, reset to 0 by the constructor.
          // NOTE(review): not read anywhere in this file; presumably advanced by batch iteration code.
          int batch_count;

          /// <summary>
          /// Number of batches in one pass over the annotations, i.e. ceil(num_samples / batch_size).
          /// Remains 0 when the configured batch size is not positive.
          /// </summary>
          public int Length = 0;

          /// <summary>
          /// Selects the configuration for one split and reads that split's annotation file.
          /// </summary>
          /// <param name="dataset_type">
          /// <c>"train"</c> selects the <c>cfg.TRAIN</c> settings; any other value (including null)
          /// selects the <c>cfg.TEST</c> settings.
          /// </param>
          /// <param name="cfg">Configuration supplying the TRAIN, TEST and YOLO sections.</param>
          /// <exception cref="ArgumentNullException"><paramref name="cfg"/> is null.</exception>
          public Dataset(string dataset_type, Config cfg)
          {
              if (cfg == null)
              {
                  throw new ArgumentNullException(nameof(cfg));
              }

              // Decide the split once; every split-dependent setting below follows this choice.
              bool isTrain = dataset_type == "train";
              annot_path = isTrain ? cfg.TRAIN.ANNOT_PATH : cfg.TEST.ANNOT_PATH;
              input_sizes = isTrain ? cfg.TRAIN.INPUT_SIZE : cfg.TEST.INPUT_SIZE;
              batch_size = isTrain ? cfg.TRAIN.BATCH_SIZE : cfg.TEST.BATCH_SIZE;
              data_aug = isTrain ? cfg.TRAIN.DATA_AUG : cfg.TEST.DATA_AUG;

              // The training input sizes are kept even when the test split is selected.
              train_input_sizes = cfg.TRAIN.INPUT_SIZE;
              strides = np.array(cfg.YOLO.STRIDES);
              classes = Utils.read_class_names(cfg.YOLO.CLASSES);
              num_classes = classes.Count;
              anchors = np.array(Utils.get_anchors(cfg.YOLO.ANCHORS));
              anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE;
              // NOTE(review): hard-coded; presumably the maximum number of boxes kept per scale — confirm.
              max_bbox_per_scale = 150;

              annotations = load_annotations();
              num_samples = annotations.Length;

              // Integer ceiling without the (a + b - 1) overflow risk; a final partial batch counts.
              if (batch_size > 0)
              {
                  Length = num_samples / batch_size + (num_samples % batch_size == 0 ? 0 : 1);
              }

              batch_count = 0;
          }

          /// <summary>
          /// Reads the annotation file at <c>annot_path</c> and returns its non-blank lines,
          /// each trimmed of surrounding whitespace, in file order.
          /// </summary>
          /// <returns>One entry per non-blank line of the annotation file.</returns>
          string[] load_annotations()
          {
              var lines = new List<string>();
              foreach (string raw in File.ReadLines(annot_path))
              {
                  string line = raw.Trim();
                  // A blank line describes no sample, so it must not inflate num_samples / Length.
                  if (line.Length > 0)
                  {
                      lines.Add(line);
                  }
              }
              return lines.ToArray();
          }
      }
  }