You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

DatasetUtils.index_directory.cs 2.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Linq;
  namespace Tensorflow.Keras.Preprocessings
  {
      public partial class DatasetUtils
      {
          /// <summary>
          /// Make list of all files in the subdirs of `directory`, with their labels.
          /// Each immediate subdirectory of <paramref name="directory"/> is one class; every
          /// matching file directly inside it receives that class's index as its label.
          /// </summary>
          /// <param name="directory">Root directory whose immediate subdirectories are the classes.</param>
          /// <param name="labels">
          /// Currently not consulted: labels are always inferred from the subdirectory each file lives in.
          /// </param>
          /// <param name="formats">
          /// Allowed file-name suffixes such as ".jpg", matched case-insensitively against the end of
          /// each file path. <c>null</c> or empty keeps every file.
          /// </param>
          /// <param name="class_names">
          /// Optional explicit class order. When supplied it must name exactly the subdirectories of
          /// <paramref name="directory"/>, each once; label <c>i</c> then refers to <c>class_names[i]</c>.
          /// When <c>null</c> or empty, the subdirectory names sorted ordinally are used.
          /// </param>
          /// <param name="shuffle">Whether to shuffle file paths and labels together (pairs stay aligned).</param>
          /// <param name="seed">Shuffle seed; when <c>null</c> one is drawn with <c>np.random.randint(1e6)</c>.</param>
          /// <param name="follow_links">
          /// Currently not consulted; only files directly inside each class subdirectory are listed.
          /// </param>
          /// <returns>
          /// file_paths, labels, class_names
          /// </returns>
          /// <exception cref="ArgumentException">
          /// <paramref name="class_names"/> was supplied but does not match the subdirectory names.
          /// </exception>
          public (string[], int[], string[]) index_directory(string directory,
              string labels,
              string[] formats = null,
              string[] class_names = null,
              bool shuffle = true,
              int? seed = null,
              bool follow_links = false)
          {
              // Directory.GetDirectories/GetFiles make no ordering guarantee, so sort ordinally.
              // Otherwise the label assigned to each class (and the shuffle result for a fixed
              // seed) could differ between machines or runs.
              var dir_by_name = Directory.GetDirectories(directory)
                  .ToDictionary(path => Path.GetFileName(path), StringComparer.Ordinal);
              var subdirs = dir_by_name.Keys
                  .OrderBy(name => name, StringComparer.Ordinal)
                  .ToArray();

              if (class_names == null || class_names.Length == 0)
              {
                  class_names = subdirs;
              }
              else if (class_names.Length != subdirs.Length
                  || !new HashSet<string>(class_names, StringComparer.Ordinal).SetEquals(subdirs))
              {
                  // The length check also rejects duplicates, which a set comparison alone would miss.
                  throw new ArgumentException(
                      "The `class_names` passed did not match the names of the subdirectories of the target directory. " +
                      $"Expected: [{string.Join(", ", subdirs)}], but received: [{string.Join(", ", class_names)}]",
                      nameof(class_names));
              }

              var keep_all_files = formats == null || formats.Length == 0;
              var label_list = new List<int>();
              var file_paths = new List<string>();
              for (var label = 0; label < class_names.Length; label++)
              {
                  var files = Directory.GetFiles(dir_by_name[class_names[label]])
                      .Where(file => keep_all_files
                          || formats.Any(ext => ext != null && file.EndsWith(ext, StringComparison.OrdinalIgnoreCase)))
                      .OrderBy(file => file, StringComparer.Ordinal)
                      .ToArray();
                  file_paths.AddRange(files);
                  label_list.AddRange(Enumerable.Repeat(label, files.Length));
              }

              var return_labels = label_list.ToArray();
              var return_file_paths = file_paths.ToArray();
              if (shuffle)
              {
                  if (!seed.HasValue)
                  {
                      seed = np.random.randint((long)1e6);
                  }
                  // Shuffle a single index permutation and apply it to both arrays so that
                  // every file path keeps its label.
                  var random_index = np.arange(label_list.Count);
                  var rng = np.random.RandomState(seed.Value);
                  rng.shuffle(random_index);
                  var index = random_index.ToArray<int>();
                  for (int i = 0; i < label_list.Count; i++)
                  {
                      return_labels[i] = label_list[index[i]];
                      return_file_paths[i] = file_paths[index[i]];
                  }
              }
              Console.WriteLine($"Found {return_file_paths.Length} files belonging to {class_names.Length} classes.");
              return (return_file_paths, return_labels, class_names);
          }
      }
  }