You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Model.Evaluate.cs 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Tensorflow;
  5. using Tensorflow.Keras.ArgsDefinition;
  6. using Tensorflow.Keras.Callbacks;
  7. using Tensorflow.Keras.Engine.DataAdapters;
  8. using Tensorflow.Keras.Layers;
  9. using Tensorflow.Keras.Utils;
  10. using Tensorflow.NumPy;
  11. using static Tensorflow.Binding;
  namespace Tensorflow.Keras.Engine
  {
      public partial class Model
      {
          /// <summary>
          /// Returns the loss value and metrics values for the model in test mode.
          /// </summary>
          /// <param name="x">Input data. Its size along dimension 0 must equal that of <paramref name="y"/>.</param>
          /// <param name="y">Target data. Its size along dimension 0 must equal that of <paramref name="x"/>.</param>
          /// <param name="batch_size">Forwarded to the data handler as <c>BatchSize</c>.</param>
          /// <param name="verbose">Verbosity level handed to the callback list.</param>
          /// <param name="steps">Forwarded to the data handler as <c>StepsPerEpoch</c>.</param>
          /// <param name="max_queue_size">Forwarded to the data handler as <c>MaxQueueSize</c>.</param>
          /// <param name="workers">Forwarded to the data handler as <c>Workers</c>.</param>
          /// <param name="use_multiprocessing">Forwarded to the data handler as <c>UseMultiprocessing</c>.</param>
          /// <param name="return_dict">Currently ignored by this overload; the result is always a dictionary.</param>
          /// <param name="is_val">When true, per-batch <c>on_test_batch_end</c> callbacks are not invoked (validation rather than test).</param>
          /// <returns>
          /// A new dictionary mapping each metric name to its value as reported after the last evaluated batch;
          /// empty if no batch was evaluated.
          /// </returns>
          /// <exception cref="ArgumentNullException"><paramref name="x"/> or <paramref name="y"/> is null.</exception>
          /// <exception cref="InvalidArgumentError">
          /// <paramref name="x"/> or <paramref name="y"/> has rank 0, or their sizes along dimension 0 differ.
          /// </exception>
          public Dictionary<string, float> evaluate(NDArray x, NDArray y,
              int batch_size = -1,
              int verbose = 1,
              int steps = -1,
              int max_queue_size = 10,
              int workers = 1,
              bool use_multiprocessing = false,
              bool return_dict = false,
              bool is_val = false)
          {
              if (x is null)
              {
                  throw new ArgumentNullException(nameof(x));
              }
              if (y is null)
              {
                  throw new ArgumentNullException(nameof(y));
              }
              // A scalar has no dimension 0; reject it here instead of failing with an index error below.
              if (x.dims.Length == 0 || y.dims.Length == 0)
              {
                  throw new InvalidArgumentError(
                      $"The array x and y should have at least one dimension, but got rank {x.dims.Length} and {y.dims.Length}");
              }
              if (x.dims[0] != y.dims[0])
              {
                  throw new InvalidArgumentError(
                      $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}");
              }

              var data_handler = new DataHandler(new DataHandlerArgs
              {
                  X = x,
                  Y = y,
                  BatchSize = batch_size,
                  StepsPerEpoch = steps,
                  InitialEpoch = 0,
                  Epochs = 1,
                  MaxQueueSize = max_queue_size,
                  Workers = workers,
                  UseMultiprocessing = use_multiprocessing,
                  Model = this,
                  StepsPerExecution = _steps_per_execution
              });
              var callbacks = new CallbackList(new CallbackParams
              {
                  Model = this,
                  Verbose = verbose,
                  Steps = data_handler.Inferredsteps
              });
              return evaluate(data_handler, callbacks, is_val, test_function);
          }

          /// <summary>
          /// Evaluates the model on one or more input tensors and a target tensor.
          /// Each batch is split by <see cref="test_step_multi_inputs_function"/> into inputs and targets.
          /// </summary>
          /// <param name="x">Input tensors; the sequence is enumerated once and wrapped in a <c>Tensors</c> collection.</param>
          /// <param name="y">Target tensor.</param>
          /// <param name="verbose">Verbosity level handed to the callback list.</param>
          /// <param name="is_val">When true, per-batch <c>on_test_batch_end</c> callbacks are not invoked.</param>
          /// <returns>A new dictionary mapping each metric name to its value after the last evaluated batch.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="x"/> is null.</exception>
          public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false)
          {
              if (x is null)
              {
                  throw new ArgumentNullException(nameof(x));
              }

              var data_handler = new DataHandler(new DataHandlerArgs
              {
                  X = new Tensors(x.ToArray()),
                  Y = y,
                  Model = this,
                  StepsPerExecution = _steps_per_execution
              });
              var callbacks = new CallbackList(new CallbackParams
              {
                  Model = this,
                  Verbose = verbose,
                  Steps = data_handler.Inferredsteps
              });
              return evaluate(data_handler, callbacks, is_val, test_step_multi_inputs_function);
          }

          /// <summary>
          /// Evaluates the model on a dataset. Each element is read by <see cref="test_function"/>,
          /// which treats element 0 as the inputs and element 1 as the targets.
          /// </summary>
          /// <param name="x">The dataset to evaluate on.</param>
          /// <param name="verbose">Verbosity level handed to the callback list.</param>
          /// <param name="is_val">When true, per-batch <c>on_test_batch_end</c> callbacks are not invoked.</param>
          /// <returns>A new dictionary mapping each metric name to its value after the last evaluated batch.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="x"/> is null.</exception>
          public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is_val = false)
          {
              if (x is null)
              {
                  throw new ArgumentNullException(nameof(x));
              }

              var data_handler = new DataHandler(new DataHandlerArgs
              {
                  Dataset = x,
                  Model = this,
                  StepsPerExecution = _steps_per_execution
              });
              var callbacks = new CallbackList(new CallbackParams
              {
                  Model = this,
                  Verbose = verbose,
                  Steps = data_handler.Inferredsteps
              });
              return evaluate(data_handler, callbacks, is_val, test_function);
          }

          /// <summary>
          /// Internal bare implementation of evaluate function.
          /// </summary>
          /// <param name="data_handler">Iteration-handling object that supplies epochs, steps and batch iterators.</param>
          /// <param name="callbacks">Callbacks notified at test begin/end and around each batch.</param>
          /// <param name="is_val">Whether it is validation or test; when true, <c>on_test_batch_end</c> is skipped.</param>
          /// <param name="test_func">The function to be called on each batch of data.</param>
          /// <returns>A copy of the metrics returned by the last call to <paramref name="test_func"/>; empty if it was never called.</returns>
          Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callbacks, bool is_val, Func<DataHandler, OwnedIterator, Dictionary<string, float>> test_func)
          {
              callbacks.on_test_begin();

              // Replaced after every batch, so it always holds the most recent batch's metrics.
              var logs = new Dictionary<string, float>();
              foreach (var (_, iterator) in data_handler.enumerate_epochs())
              {
                  // Metric state is cleared at the start of every epoch.
                  reset_metrics();
                  foreach (var step in data_handler.steps())
                  {
                      callbacks.on_test_batch_begin(step);
                      logs = test_func(data_handler, iterator);
                      var end_step = step + data_handler.StepIncrement;
                      if (!is_val)
                      {
                          callbacks.on_test_batch_end(end_step, logs);
                      }
                  }
              }

              callbacks.on_test_end(logs);
              // Hand back a copy rather than the instance that was passed to on_test_end.
              return new Dictionary<string, float>(logs);
          }

          /// <summary>
          /// Pulls one batch from <paramref name="iterator"/>, evaluates it with element 0 as inputs and
          /// element 1 as targets, then increments <c>_test_counter</c>.
          /// </summary>
          /// <param name="data_handler">Handler whose data adapter is used to expand the batch.</param>
          /// <param name="iterator">Iterator yielding the batches for the current epoch.</param>
          /// <returns>The per-metric results reported by <see cref="test_step"/>.</returns>
          Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
          {
              var data = iterator.next();
              var outputs = test_step(data_handler, data[0], data[1]);
              tf_with(ops.control_dependencies(Array.Empty<object>()), ctl => _test_counter.assign_add(1));
              return outputs;
          }

          /// <summary>
          /// Pulls one batch from <paramref name="iterator"/> and splits it: the first
          /// <c>FirstInputTensorCount</c> tensors are the inputs, the remaining ones the targets.
          /// Increments <c>_test_counter</c> afterwards.
          /// </summary>
          /// <param name="data_handler">Handler whose dataset reports how many leading tensors are inputs.</param>
          /// <param name="iterator">Iterator yielding the batches for the current epoch.</param>
          /// <returns>The per-metric results reported by <see cref="test_step"/>.</returns>
          Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator)
          {
              var data = iterator.next();
              var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
              var outputs = test_step(data_handler, data.Take(x_size).ToArray(), data.Skip(x_size).ToArray());
              tf_with(ops.control_dependencies(Array.Empty<object>()), ctl => _test_counter.assign_add(1));
              return outputs;
          }

          /// <summary>
          /// Runs one evaluation step: a forward pass with <c>training: false</c>, then updates the
          /// compiled loss and compiled metrics.
          /// </summary>
          /// <param name="data_handler">Handler whose data adapter expands 1-D inputs/targets.</param>
          /// <param name="x">Batch inputs.</param>
          /// <param name="y">Batch targets.</param>
          /// <returns>Each entry of <c>metrics</c>, keyed by its name, with its current result as a float.</returns>
          Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
          {
              (x, y) = data_handler.DataAdapter.Expand1d(x, y);
              var y_pred = Apply(x, training: false);
              // The returned loss tensor is not used here. The call is kept because it presumably
              // updates loss-tracking state that `metrics` reports — verify against compiled_loss.
              compiled_loss.Call(y, y_pred);
              compiled_metrics.update_state(y, y_pred);
              return metrics.ToDictionary(metric => metric.Name, metric => (float)metric.result());
          }
      }
  }