You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, may include dashes ('-'), and can be up to 35 characters long.

Datasets.cs 1.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  namespace TensorFlowNET.Examples.Utility
  {
      /// <summary>
      /// Holds the train, validation and test splits of a data set, plus helpers for
      /// shuffling and mini-batching <see cref="NDArray"/> inputs.
      /// </summary>
      /// <typeparam name="T">Per-split data set type; must implement <c>IDataSet</c>.</typeparam>
      public class Datasets<T> where T : IDataSet
      {
          private readonly T _train;

          /// <summary>The training split supplied to the constructor.</summary>
          public T train => _train;

          private readonly T _validation;

          /// <summary>The validation split supplied to the constructor.</summary>
          public T validation => _validation;

          private readonly T _test;

          /// <summary>The test split supplied to the constructor.</summary>
          public T test => _test;

          /// <summary>
          /// Creates a container for the three data set splits. The values are stored as given.
          /// </summary>
          /// <param name="train">Training split.</param>
          /// <param name="validation">Validation split.</param>
          /// <param name="test">Test split.</param>
          public Datasets(T train, T validation, T test)
          {
              _train = train;
              _validation = validation;
              _test = test;
          }

          /// <summary>
          /// Shuffles <paramref name="x"/> and <paramref name="y"/> with one shared random
          /// permutation of their first axis, so each sample stays paired with its label.
          /// </summary>
          /// <param name="x">Samples. Its first dimension must equal that of <paramref name="y"/>.</param>
          /// <param name="y">Labels. The permutation length is taken from <c>y.shape[0]</c>.</param>
          /// <returns>The shuffled <c>(x, y)</c> pair.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="x"/> or <paramref name="y"/> is null.</exception>
          /// <exception cref="ArgumentException">
          /// <paramref name="x"/> and <paramref name="y"/> differ in their first dimension.
          /// </exception>
          public (NDArray, NDArray) Randomize(NDArray x, NDArray y)
          {
              if (x is null)
              {
                  throw new ArgumentNullException(nameof(x));
              }
              if (y is null)
              {
                  throw new ArgumentNullException(nameof(y));
              }
              if (x.shape[0] != y.shape[0])
              {
                  throw new ArgumentException(
                      $"x has {x.shape[0]} rows but y has {y.shape[0]}; they must match.", nameof(x));
              }

              var perm = np.random.permutation(y.shape[0]);
              // permutation() already yields a shuffled index array. The extra shuffle is kept
              // so that seeded runs consume the random stream exactly as before.
              np.random.shuffle(perm);

              // Index the arguments rather than the train split, so any split can be shuffled.
              // NOTE(review): relies on NumSharp integer-array indexing selecting rows along
              // axis 0, as in NumPy.
              return (x[perm], y[perm]);
          }

          /// <summary>
          /// Selects one mini-batch (e.g. for a step of stochastic gradient descent). It takes
          /// the rows in the half-open range [<paramref name="start"/>, <paramref name="end"/>)
          /// from both <paramref name="x"/> and <paramref name="y"/>.
          /// </summary>
          /// <param name="x">Samples to slice.</param>
          /// <param name="y">Labels to slice with the same range as <paramref name="x"/>.</param>
          /// <param name="start">Index of the first row of the batch.</param>
          /// <param name="end">Index one past the last row of the batch.</param>
          /// <returns>The <c>(x_batch, y_batch)</c> pair.</returns>
          public (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end)
          {
              // NumSharp accepts NumPy-style slice strings; "start:end" slices the first axis.
              var slice = $"{start}:{end}";
              var x_batch = x[slice];
              var y_batch = y[slice];
              return (x_batch, y_batch);
          }
      }
  }