You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

Model.Predict.cs 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Tensorflow.Keras.ArgsDefinition;
  5. using Tensorflow.Keras.Engine.DataAdapters;
  6. using static Tensorflow.Binding;
  7. using Tensorflow.Keras.Callbacks;
  8. namespace Tensorflow.Keras.Engine
  9. {
  10. public partial class Model
  11. {
  12. public Tensors predict(IDatasetV2 dataset,
  13. int batch_size = -1,
  14. int verbose = 0,
  15. int steps = -1,
  16. int max_queue_size = 10,
  17. int workers = 1,
  18. bool use_multiprocessing = false)
  19. {
  20. var data_handler = new DataHandler(new DataHandlerArgs
  21. {
  22. Dataset = dataset,
  23. BatchSize = batch_size,
  24. StepsPerEpoch = steps,
  25. InitialEpoch = 0,
  26. Epochs = 1,
  27. MaxQueueSize = max_queue_size,
  28. Workers = workers,
  29. UseMultiprocessing = use_multiprocessing,
  30. Model = this,
  31. StepsPerExecution = _steps_per_execution
  32. });
  33. return PredictInternal(data_handler, verbose);
  34. }
  35. /// <summary>
  36. /// Generates output predictions for the input samples.
  37. /// </summary>
  38. /// <param name="x">Input samples</param>
  39. /// <param name="batch_size">Number of samples per batch</param>
  40. /// <param name="verbose">Verbosity mode</param>
  41. /// <param name="steps">
  42. /// Total number of steps (batches of samples)
  43. /// before declaring the prediction round finished.
  44. /// </param>
  45. /// <param name="max_queue_size"></param>
  46. /// <param name="workers"></param>
  47. /// <param name="use_multiprocessing"></param>
  48. /// <returns></returns>
  49. public Tensors predict(Tensors x,
  50. int batch_size = -1,
  51. int verbose = 0,
  52. int steps = -1,
  53. int max_queue_size = 10,
  54. int workers = 1,
  55. bool use_multiprocessing = false)
  56. {
  57. var data_handler = new DataHandler(new DataHandlerArgs
  58. {
  59. X = x,
  60. BatchSize = batch_size,
  61. StepsPerEpoch = steps,
  62. InitialEpoch = 0,
  63. Epochs = 1,
  64. MaxQueueSize = max_queue_size,
  65. Workers = workers,
  66. UseMultiprocessing = use_multiprocessing,
  67. Model = this,
  68. StepsPerExecution = _steps_per_execution
  69. });
  70. return PredictInternal(data_handler, verbose);
  71. }
  72. Tensors PredictInternal(DataHandler data_handler, int verbose)
  73. {
  74. var callbacks = new CallbackList(new CallbackParams
  75. {
  76. Model = this,
  77. Verbose = verbose,
  78. Epochs = 1,
  79. Steps = data_handler.Inferredsteps
  80. });
  81. Tensors batch_outputs = null;
  82. _predict_counter.assign(0);
  83. callbacks.on_predict_begin();
  84. foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
  85. {
  86. foreach (var step in data_handler.steps())
  87. {
  88. callbacks.on_predict_batch_begin(step);
  89. var tmp_batch_outputs = run_predict_step(iterator);
  90. if (batch_outputs == null)
  91. {
  92. batch_outputs = tmp_batch_outputs;
  93. }
  94. else
  95. {
  96. for (int i = 0; i < batch_outputs.Length; i++)
  97. batch_outputs[i] = tf.concat(new Tensor[] { batch_outputs[i], tmp_batch_outputs[i] }, axis: 0);
  98. }
  99. var end_step = step + data_handler.StepIncrement;
  100. callbacks.on_predict_batch_end(end_step, new Dictionary<string, Tensors> { { "outputs", batch_outputs } });
  101. GC.Collect();
  102. }
  103. }
  104. callbacks.on_predict_end();
  105. return batch_outputs;
  106. }
  107. Tensors run_predict_step(OwnedIterator iterator)
  108. {
  109. var data = iterator.next();
  110. var outputs = predict_step(data);
  111. tf_with(ops.control_dependencies(Array.Empty<object>()), ctl => _predict_counter.assign_add(1));
  112. return outputs;
  113. }
  114. Tensors predict_step(Tensors data)
  115. {
  116. return Apply(data, training: false);
  117. }
  118. }
  119. }