You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

DataHandler.cs 4.6 kB

4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. using System;
  2. using System.Collections.Generic;
  3. using Tensorflow.Keras.ArgsDefinition;
  4. using static Tensorflow.Binding;
  namespace Tensorflow.Keras.Engine.DataAdapters
  {
      /// <summary>
      /// Handles iterating over epoch-level `tf.data.Iterator` objects.
      /// </summary>
      /// <remarks>
      /// Picks a <see cref="TensorLikeDataAdapter"/> when <c>args.Dataset</c> is null (data given as
      /// X/Y) or a <see cref="DatasetAdapter"/> otherwise, infers the number of steps per epoch, and
      /// exposes <see cref="enumerate_epochs"/> and <see cref="steps"/> for the caller's loops.
      /// </remarks>
      public class DataHandler
      {
          DataHandlerArgs args;
          IDataAdapter _adapter;
          public IDataAdapter DataAdapter => _adapter;
          IDatasetV2 _dataset;
          // Steps per epoch, as computed by _infer_steps.
          int _inferred_steps;
          int _current_step;
          // Steps covered by the current execution minus one; updated by steps().
          int _step_increment;
          // Checked by enumerate_epochs/steps but never set to true within this class.
          // NOTE(review): presumably intended to be set when the iterator runs out of data — confirm.
          bool _insufficient_data;
          // Host-side int copy of the configured steps-per-execution.
          int _steps_per_execution_value;
          int _initial_epoch => args.InitialEpoch;
          int _epochs => args.Epochs;
          IVariableV1 _steps_per_execution;

          /// <summary>
          /// Builds the data adapter and dataset for <paramref name="args"/> and infers the number
          /// of steps per epoch.
          /// </summary>
          /// <param name="args">
          /// Handler configuration. <c>StepsPerExecution</c> may be null, in which case a fresh
          /// variable holding 1 is created.
          /// </param>
          /// <exception cref="ArgumentNullException"><paramref name="args"/> is null.</exception>
          public DataHandler(DataHandlerArgs args)
          {
              this.args = args ?? throw new ArgumentNullException(nameof(args));

              if (args.StepsPerExecution == null)
              {
                  _steps_per_execution = tf.Variable(1);
                  _steps_per_execution_value = 1;
              }
              else
              {
                  _steps_per_execution = args.StepsPerExecution;
                  _steps_per_execution_value = args.StepsPerExecution.numpy();
              }

              if (args.Dataset == null)
              {
                  _adapter = new TensorLikeDataAdapter(new DataAdapterArgs
                  {
                      X = args.X,
                      Y = args.Y,
                      BatchSize = args.BatchSize,
                      Steps = args.StepsPerEpoch,
                      Epochs = args.Epochs - args.InitialEpoch,
                      Shuffle = args.Shuffle,
                      MaxQueueSize = args.MaxQueueSize,
                      Worker = args.Workers,
                      UseMultiprocessing = args.UseMultiprocessing,
                      Model = args.Model
                  });
              }
              else
              {
                  _adapter = new DatasetAdapter(new DataAdapterArgs
                  {
                      Dataset = args.Dataset,
                      BatchSize = args.BatchSize,
                      Steps = args.StepsPerEpoch,
                      Epochs = args.Epochs - args.InitialEpoch,
                      Shuffle = args.Shuffle,
                      MaxQueueSize = args.MaxQueueSize,
                      Worker = args.Workers,
                      UseMultiprocessing = args.UseMultiprocessing,
                      Model = args.Model
                  });
              }

              _dataset = _adapter.GetDataset();
              _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
              _current_step = 0;
              // Use the cached value: args.StepsPerExecution may be null (handled above), and
              // dereferencing it here would throw NullReferenceException.
              _step_increment = _steps_per_execution_value - 1;
              _insufficient_data = false;
          }

          /// <summary>
          /// Determines how many steps make up one epoch.
          /// </summary>
          /// <param name="steps_per_epoch">Explicit step count; used when it is greater than -1.</param>
          /// <param name="dataset">Dataset whose cardinality is the last-resort source.</param>
          /// <returns>
          /// <paramref name="steps_per_epoch"/> if set, else the adapter's size if it is greater than -1,
          /// else the dataset cardinality.
          /// </returns>
          /// <remarks>
          /// NOTE(review): tf.data reports infinite/unknown cardinality as negative sentinels; such a
          /// value is returned unchanged here, and <see cref="steps"/> then yields nothing — confirm
          /// this is intended.
          /// </remarks>
          int _infer_steps(int steps_per_epoch, IDatasetV2 dataset)
          {
              if (steps_per_epoch > -1)
                  return steps_per_epoch;

              var adapter_steps = _adapter.GetSize();
              if (adapter_steps > -1)
                  return adapter_steps;

              var size = dataset.dataset_cardinality();
              return size.numpy();
          }

          /// <summary>
          /// Yields <c>(epoch, iterator)</c> for each epoch in [InitialEpoch, Epochs).
          /// </summary>
          /// <remarks>
          /// A single <see cref="OwnedIterator"/> is created for the whole enumeration and shared by
          /// every epoch. It is disposed when the enumeration finishes or is disposed early.
          /// </remarks>
          public IEnumerable<(int, OwnedIterator)> enumerate_epochs()
          {
              using var ownedIterator = new OwnedIterator(_dataset);
              foreach (var epoch in range(_initial_epoch, _epochs))
              {
                  if (_insufficient_data)
                      break;
                  yield return (epoch, ownedIterator);
              }
          }

          /// <summary>
          /// Yields the starting step index of each execution in the current epoch.
          /// </summary>
          /// <remarks>
          /// Normally advances by the configured steps-per-execution. When fewer steps remain than a
          /// full execution, the steps-per-execution variable is temporarily set to the remaining
          /// count for that last execution and then restored.
          /// </remarks>
          public IEnumerable<int> steps()
          {
              _current_step = 0;
              while (_current_step < _inferred_steps)
              {
                  if (_insufficient_data)
                      break;

                  // NOTE(review): `_inferred_steps < 0` cannot be true inside this loop, because the
                  // loop condition already requires _current_step (>= 0) < _inferred_steps.
                  bool can_run_full_execution = _steps_per_execution_value == 1
                      || _inferred_steps < 0
                      || _inferred_steps - _current_step >= _steps_per_execution_value;

                  if (can_run_full_execution)
                  {
                      _step_increment = _steps_per_execution_value - 1;
                      yield return _current_step;
                      _current_step += _steps_per_execution_value;
                  }
                  else
                  {
                      var steps_remaining = _inferred_steps - _current_step;
                      _steps_per_execution.assign(steps_remaining);
                      _step_increment = steps_remaining - 1;
                      try
                      {
                          yield return _current_step;
                          _current_step += steps_remaining;
                      }
                      finally
                      {
                          // Restore even if the caller stops enumerating early (break/exception),
                          // so the temporarily shrunken value does not persist.
                          _steps_per_execution.assign(_steps_per_execution_value);
                      }
                  }
              }
          }
      }
  }