You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

DataSetMnist.cs 3.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using Tensorflow;
  namespace TensorFlowNET.Examples.Utility
  {
      /// <summary>
      /// In-memory MNIST split (images + labels) that hands out mini-batches and,
      /// when requested, reshuffles the examples before the first epoch and at
      /// every epoch boundary.
      /// </summary>
      public class DataSetMnist : IDataSet
      {
          /// <summary>Number of examples, taken from the first dimension of the images given to the constructor.</summary>
          public int num_examples { get; }

          /// <summary>Number of times <see cref="next_batch"/> has run past the end of the data.</summary>
          public int epochs_completed { get; private set; }

          /// <summary>Row index within the current epoch at which the next batch starts.</summary>
          public int index_in_epoch { get; private set; }

          /// <summary>Flattened images scaled by 1/255; row order changes whenever the set is shuffled.</summary>
          public NDArray data { get; private set; }

          /// <summary>Labels, kept row-aligned with <see cref="data"/>.</summary>
          public NDArray labels { get; private set; }

          /// <summary>
          /// Builds the data set from raw image and label arrays.
          /// </summary>
          /// <param name="images">
          /// Image array indexed as (example, rows, columns) — assumes a 3-D array; TODO confirm against the loader.
          /// </param>
          /// <param name="labels">Labels, one row per example.</param>
          /// <param name="dtype">Element type that both images and labels are converted to.</param>
          /// <param name="reshape">Currently ignored: images are always flattened to (example, rows * columns).</param>
          /// <exception cref="ArgumentNullException"><paramref name="images"/> or <paramref name="labels"/> is null.</exception>
          public DataSetMnist(NDArray images, NDArray labels, TF_DataType dtype, bool reshape)
          {
              // `is null` rather than `== null`: NDArray overloads the equality operator.
              if (images is null)
              {
                  throw new ArgumentNullException(nameof(images));
              }
              if (labels is null)
              {
                  throw new ArgumentNullException(nameof(labels));
              }

              num_examples = images.shape[0];

              // NOTE(review): `reshape` is not consulted; images are always flattened.
              images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]);

              // The result of astype must be assigned. It may return a converted copy
              // (as NumPy's astype does) instead of converting in place, in which case
              // discarding it would silently keep the original element type.
              var numpyType = dtype.as_numpy_datatype();
              images = images.astype(numpyType);
              images = np.multiply(images, 1.0f / 255.0f);
              labels = labels.astype(numpyType);

              data = images;
              this.labels = labels;
              epochs_completed = 0;
              index_in_epoch = 0;
          }

          /// <summary>
          /// Returns the next mini-batch of (images, labels) and advances the epoch cursor.
          /// </summary>
          /// <param name="batch_size">Number of examples requested; must not be negative.</param>
          /// <param name="fake_data">Unused; kept for signature compatibility.</param>
          /// <param name="shuffle">
          /// Whether to permute the examples before the very first batch and each time an epoch ends.
          /// </param>
          /// <returns>
          /// A row-aligned (images, labels) pair. When the request runs past the end of the
          /// current epoch, only the first <c>batch_size - (num_examples - index_in_epoch)</c>
          /// rows of the next epoch are returned. That batch is therefore smaller than
          /// <paramref name="batch_size"/>.
          /// </returns>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="batch_size"/> is negative.</exception>
          public (NDArray, NDArray) next_batch(int batch_size, bool fake_data = false, bool shuffle = true)
          {
              if (batch_size < 0)
              {
                  throw new ArgumentOutOfRangeException(nameof(batch_size), batch_size, "batch_size must not be negative.");
              }

              var start = index_in_epoch;

              // Shuffle once before the very first batch is handed out.
              if (epochs_completed == 0 && start == 0 && shuffle)
              {
                  ShuffleRows();
              }

              if (start + batch_size <= num_examples)
              {
                  // The whole batch fits inside the current epoch.
                  index_in_epoch += batch_size;
                  return TakeRows(start, index_in_epoch);
              }

              // The batch crosses the end of the epoch.
              epochs_completed += 1;
              var rest_num_examples = num_examples - start;

              // NOTE(review): the leftover rows [start, num_examples) of the finished epoch
              // are meant to be prepended to this batch, but that concatenation is not
              // implemented. Those rows are dropped and the returned batch is short.
              if (shuffle)
              {
                  ShuffleRows();
              }

              index_in_epoch = batch_size - rest_num_examples;
              return TakeRows(0, index_in_epoch);
          }

          // Applies one random permutation to both data and labels so rows stay aligned.
          private void ShuffleRows()
          {
              var perm = np.arange(num_examples);
              np.random.shuffle(perm);
              data = data[perm];
              labels = labels[perm];
          }

          // Returns rows [start, end) of data and labels.
          private (NDArray, NDArray) TakeRows(int start, int end)
          {
              var rows = np.arange(start, end);
              return (data[rows], labels[rows]);
          }
      }
  }