You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

DataSet.cs 3.2 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using Tensorflow;
  6. namespace TensorFlowNET.Examples.Utility
  7. {
  /// <summary>
  /// In-memory image/label dataset that hands out mini-batches and (optionally)
  /// reshuffles the examples before the first batch and at every epoch boundary.
  /// </summary>
  public class DataSet
  {
      private readonly int _num_examples;
      /// <summary>Number of examples (rows); taken from <c>images.shape[0]</c> at construction.</summary>
      public int num_examples => _num_examples;

      private int _epochs_completed;
      /// <summary>How many times <see cref="next_batch"/> has run past the end of the data.</summary>
      public int epochs_completed => _epochs_completed;

      private int _index_in_epoch;
      /// <summary>Row index of the next example to be returned within the current epoch.</summary>
      public int index_in_epoch => _index_in_epoch;

      private NDArray _images;
      /// <summary>Flattened images after dtype conversion and 1/255 scaling; row order changes on shuffle.</summary>
      public NDArray images => _images;

      private NDArray _labels;
      /// <summary>Labels after dtype conversion, kept row-aligned with <see cref="images"/>.</summary>
      public NDArray labels => _labels;

      /// <summary>
      /// Builds a dataset from raw images and their labels.
      /// </summary>
      /// <param name="images">
      /// Image array whose first axis indexes examples. Only axes 0..2 are used to flatten,
      /// so it is presumably (N, H, W) — TODO confirm for inputs with a channel axis.
      /// </param>
      /// <param name="labels">Labels whose first axis indexes examples, row-aligned with <paramref name="images"/>.</param>
      /// <param name="dtype">Element type both arrays are converted to.</param>
      /// <param name="reshape">Not used: images are always flattened to (N, H * W).</param>
      /// <exception cref="ArgumentNullException"><paramref name="images"/> or <paramref name="labels"/> is null.</exception>
      /// <exception cref="ArgumentException">The two arrays disagree on the number of examples.</exception>
      public DataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape)
      {
          // `is null` rather than `== null`: NDArray overloads `==` as an element-wise comparison.
          if (images is null)
              throw new ArgumentNullException(nameof(images));
          if (labels is null)
              throw new ArgumentNullException(nameof(labels));
          if (images.shape[0] != labels.shape[0])
              throw new ArgumentException(
                  $"images.shape[0] ({images.shape[0]}) must equal labels.shape[0] ({labels.shape[0]}).",
                  nameof(labels));

          _num_examples = images.shape[0];

          // Flatten each example into one row: (N, H, W) -> (N, H * W).
          images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]);

          // astype's result must be assigned back; previously it was discarded,
          // so the requested dtype conversion never happened.
          var npType = dtype.as_numpy_datatype();
          images = images.astype(npType);
          // Scale by 1/255 — presumably mapping raw 0..255 pixel intensities to [0, 1].
          images = np.multiply(images, 1.0f / 255.0f);
          labels = labels.astype(npType);

          _images = images;
          _labels = labels;
          _epochs_completed = 0;
          _index_in_epoch = 0;
      }

      /// <summary>
      /// Returns the next <paramref name="batch_size"/> examples, starting a new epoch
      /// when the current one does not have enough examples left.
      /// </summary>
      /// <param name="batch_size">Number of rows to return; must be between 1 and <see cref="num_examples"/>.</param>
      /// <param name="fake_data">Not used by this implementation.</param>
      /// <param name="shuffle">When true, rows are reshuffled before the first batch and at each epoch boundary.</param>
      /// <returns>Row-aligned (images, labels) slices, each with exactly <paramref name="batch_size"/> rows.</returns>
      /// <exception cref="ArgumentOutOfRangeException">
      /// <paramref name="batch_size"/> is less than 1 or greater than <see cref="num_examples"/>.
      /// </exception>
      public (NDArray, NDArray) next_batch(int batch_size, bool fake_data = false, bool shuffle = true)
      {
          if (batch_size < 1 || batch_size > _num_examples)
              throw new ArgumentOutOfRangeException(nameof(batch_size), batch_size,
                  $"batch_size must be between 1 and num_examples ({_num_examples}).");

          var start = _index_in_epoch;

          // Shuffle once before handing out the very first batch.
          if (_epochs_completed == 0 && start == 0 && shuffle)
              shuffle_rows();

          if (start + batch_size > _num_examples)
          {
              // Not enough examples left: finish this epoch and begin the next one.
              _epochs_completed += 1;
              if (shuffle)
                  shuffle_rows();

              // The leftover tail of the finished epoch is skipped (joining it to the head
              // of the next epoch would need an NDArray concatenation). The batch is taken
              // entirely from the new epoch so it always has batch_size rows; previously
              // it returned only batch_size - leftover rows.
              start = 0;
          }

          _index_in_epoch = start + batch_size;
          var rows = np.arange(start, _index_in_epoch);
          return (_images[rows], _labels[rows]);
      }

      /// <summary>Applies one random permutation to both images and labels so their rows stay aligned.</summary>
      private void shuffle_rows()
      {
          var perm = np.arange(_num_examples);
          np.random.shuffle(perm);
          _images = _images[perm];
          _labels = _labels[perm];
      }
  }
  77. }