You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Datasets.cs 1.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. using NumSharp;
  2. namespace TensorFlowNET.Examples.Utility
  3. {
  /// <summary>
  /// Holds the training, validation and test splits of a dataset and provides
  /// helpers for shuffling and slicing paired arrays.
  /// </summary>
  /// <typeparam name="T">Dataset type stored for each split.</typeparam>
  public class Datasets<T> where T : IDataSet
  {
      /// <summary>Training split supplied to the constructor.</summary>
      public T train { get; }

      /// <summary>Validation split supplied to the constructor.</summary>
      public T validation { get; }

      /// <summary>Test split supplied to the constructor.</summary>
      public T test { get; }

      /// <summary>
      /// Creates a container for the three dataset splits.
      /// </summary>
      /// <param name="train">Split exposed through <see cref="train"/>.</param>
      /// <param name="validation">Split exposed through <see cref="validation"/>.</param>
      /// <param name="test">Split exposed through <see cref="test"/>.</param>
      public Datasets(T train, T validation, T test)
      {
          this.train = train;
          this.validation = validation;
          this.test = test;
      }

      /// <summary>
      /// Indexes <paramref name="x"/> and <paramref name="y"/> with one shared random
      /// permutation, so that entries selected by the same index stay paired.
      /// </summary>
      /// <param name="x">First array to reorder.</param>
      /// <param name="y">Second array to reorder; <c>y.shape[0]</c> determines the permutation length.</param>
      /// <returns>The reordered <paramref name="x"/> and <paramref name="y"/>, in that order.</returns>
      /// <remarks>
      /// Assumes <paramref name="x"/> has at least <c>y.shape[0]</c> entries along its
      /// first axis; this is not checked here.
      /// </remarks>
      public (NDArray, NDArray) Randomize(NDArray x, NDArray y)
      {
          var perm = np.random.permutation(y.shape[0]);

          // NOTE(review): permutation presumably already yields a random order, which
          // would make this second shuffle redundant. It is kept so the randomization
          // behavior stays exactly as before; confirm against the NumSharp version in use.
          np.random.shuffle(perm);

          return (x[perm], y[perm]);
      }

      /// <summary>
      /// Returns the <c>start:end</c> slice of both <paramref name="x"/> and
      /// <paramref name="y"/>, e.g. to take one mini-batch for stochastic gradient descent.
      /// </summary>
      /// <param name="x">First array to slice.</param>
      /// <param name="y">Second array to slice with the same bounds.</param>
      /// <param name="start">Bound written before the colon of the slice notation.</param>
      /// <param name="end">Bound written after the colon of the slice notation (exclusive under NumPy-style slicing).</param>
      /// <returns>The slices of <paramref name="x"/> and <paramref name="y"/>, in that order.</returns>
      public (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end)
      {
          // Slice notation is machine-readable text, so format the bounds with the
          // invariant culture: some cultures render the negative sign as U+2212
          // instead of ASCII '-', which would corrupt a slice with a negative bound.
          var slice = System.FormattableString.Invariant($"{start}:{end}");

          var x_batch = x[slice];
          var y_batch = y[slice];
          return (x_batch, y_batch);
      }
  }
  39. }