You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Model.Evaluate.cs 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Tensorflow;
  5. using Tensorflow.Keras.ArgsDefinition;
  6. using Tensorflow.Keras.Callbacks;
  7. using Tensorflow.Keras.Engine.DataAdapters;
  8. using Tensorflow.Keras.Layers;
  9. using Tensorflow.Keras.Utils;
  10. using Tensorflow.NumPy;
  11. using static Tensorflow.Binding;
  namespace Tensorflow.Keras.Engine
  {
      public partial class Model
      {
          /// <summary>
          /// Returns the loss value and metrics values for the model in test mode.
          /// </summary>
          /// <param name="x">Input samples. Must not be null; its size along dimension 0 must equal that of <paramref name="y"/>.</param>
          /// <param name="y">Target values. Must not be null.</param>
          /// <param name="batch_size">Forwarded to <c>DataHandlerArgs.BatchSize</c>; the default -1 is passed through unchanged (presumably lets the data handler choose — confirm).</param>
          /// <param name="verbose">Verbosity level forwarded to the callbacks (<c>CallbackParams.Verbose</c>).</param>
          /// <param name="sample_weight">Optional per-sample weights forwarded to <c>DataHandlerArgs.SampleWeight</c>.</param>
          /// <param name="steps">Number of evaluation steps, forwarded to <c>DataHandlerArgs.StepsPerEpoch</c>; -1 is passed through unchanged.</param>
          /// <param name="max_queue_size">Forwarded to <c>DataHandlerArgs.MaxQueueSize</c>.</param>
          /// <param name="workers">Forwarded to <c>DataHandlerArgs.Workers</c>.</param>
          /// <param name="use_multiprocessing">Forwarded to <c>DataHandlerArgs.UseMultiprocessing</c>.</param>
          /// <param name="return_dict">Currently ignored: a dictionary is always returned.</param>
          /// <param name="is_val">When <c>true</c>, <c>on_test_batch_end</c> is not raised for individual batches (validation rather than test).</param>
          /// <returns>Metric name to value, as reported by the last evaluated batch.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="x"/> or <paramref name="y"/> is null.</exception>
          /// <exception cref="InvalidArgumentError">The sizes of <paramref name="x"/> and <paramref name="y"/> along dimension 0 differ.</exception>
          public Dictionary<string, float> evaluate(NDArray x, NDArray y,
              int batch_size = -1,
              int verbose = 1,
              NDArray sample_weight = null,
              int steps = -1,
              int max_queue_size = 10,
              int workers = 1,
              bool use_multiprocessing = false,
              bool return_dict = false,
              bool is_val = false)
          {
              // Fail fast with a clear error instead of a NullReferenceException on x.dims below.
              if (x is null)
              {
                  throw new ArgumentNullException(nameof(x));
              }
              if (y is null)
              {
                  throw new ArgumentNullException(nameof(y));
              }
              if (x.dims[0] != y.dims[0])
              {
                  throw new InvalidArgumentError(
                      $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}");
              }

              var data_handler = new DataHandler(new DataHandlerArgs
              {
                  X = x,
                  Y = y,
                  BatchSize = batch_size,
                  StepsPerEpoch = steps,
                  InitialEpoch = 0,
                  Epochs = 1,
                  SampleWeight = sample_weight,
                  MaxQueueSize = max_queue_size,
                  Workers = workers,
                  UseMultiprocessing = use_multiprocessing,
                  Model = this,
                  StepsPerExecution = _steps_per_execution
              });
              var callbacks = new CallbackList(new CallbackParams
              {
                  Model = this,
                  Verbose = verbose,
                  Steps = data_handler.Inferredsteps
              });
              return evaluate(data_handler, callbacks, is_val, test_function);
          }

          /// <summary>
          /// Returns the loss value and metrics values for a model fed with several input tensors.
          /// </summary>
          /// <param name="x">Input tensors; they are packed into a single <c>Tensors</c> for the data handler.</param>
          /// <param name="y">Target tensor.</param>
          /// <param name="verbose">Verbosity level forwarded to the callbacks.</param>
          /// <param name="is_val">When <c>true</c>, <c>on_test_batch_end</c> is not raised for individual batches.</param>
          /// <returns>Metric name to value, as reported by the last evaluated batch.</returns>
          public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false)
          {
              var data_handler = new DataHandler(new DataHandlerArgs
              {
                  X = new Tensors(x.ToArray()),
                  Y = y,
                  Model = this,
                  StepsPerExecution = _steps_per_execution
              });
              var callbacks = new CallbackList(new CallbackParams
              {
                  Model = this,
                  Verbose = verbose,
                  Steps = data_handler.Inferredsteps
              });
              return evaluate(data_handler, callbacks, is_val, test_step_multi_inputs_function);
          }

          /// <summary>
          /// Returns the loss value and metrics values for the model, reading batches from a dataset.
          /// </summary>
          /// <param name="x">Dataset that supplies the evaluation data.</param>
          /// <param name="verbose">Verbosity level forwarded to the callbacks.</param>
          /// <param name="is_val">When <c>true</c>, <c>on_test_batch_end</c> is not raised for individual batches.</param>
          /// <returns>Metric name to value, as reported by the last evaluated batch.</returns>
          public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is_val = false)
          {
              var data_handler = new DataHandler(new DataHandlerArgs
              {
                  Dataset = x,
                  Model = this,
                  StepsPerExecution = _steps_per_execution
              });
              var callbacks = new CallbackList(new CallbackParams
              {
                  Model = this,
                  Verbose = verbose,
                  Steps = data_handler.Inferredsteps
              });
              return evaluate(data_handler, callbacks, is_val, test_function);
          }

          /// <summary>
          /// Internal bare implementation of evaluate function.
          /// Resets the metrics at the start of every epoch, runs <paramref name="test_func"/> once per step
          /// and raises the test callbacks around the whole run and around each batch.
          /// </summary>
          /// <param name="data_handler">Iteration handling object that supplies epochs and steps.</param>
          /// <param name="callbacks">Callbacks notified of test begin/end and of each batch.</param>
          /// <param name="is_val">Whether it is validation or test; when <c>true</c>, <c>on_test_batch_end</c> is skipped.</param>
          /// <param name="test_func">The function to be called on each batch of data.</param>
          /// <returns>A copy of the logs produced by the last executed step (empty if no step ran).</returns>
          Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callbacks, bool is_val, Func<DataHandler, OwnedIterator, Dictionary<string, float>> test_func)
          {
              callbacks.on_test_begin();
              var logs = new Dictionary<string, float>();
              foreach (var (_, iterator) in data_handler.enumerate_epochs())
              {
                  reset_metrics();
                  foreach (var step in data_handler.steps())
                  {
                      callbacks.on_test_batch_begin(step);
                      logs = test_func(data_handler, iterator);
                      var end_step = step + data_handler.StepIncrement;
                      if (!is_val)
                      {
                          callbacks.on_test_batch_end(end_step, logs);
                      }
                      // NOTE(review): a full collection on every batch is costly; presumably here to
                      // release native tensor memory promptly — confirm before removing.
                      GC.Collect();
                  }
              }
              callbacks.on_test_end(logs);
              // Return a copy so callers cannot mutate the dictionary handed to on_test_end.
              return new Dictionary<string, float>(logs);
          }

          /// <summary>
          /// Runs one evaluation step on single-input data. Pulls the next element from
          /// <paramref name="iterator"/>. A two-component element is treated as (x, y);
          /// any other element as (x, y, sample_weight). Also increments <c>_test_counter</c>.
          /// </summary>
          Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
          {
              var data = iterator.next();
              var outputs = data.Length == 2
                  ? test_step(data_handler, data[0], data[1])
                  : test_step(data_handler, data[0], data[1], data[2]);
              tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
              return outputs;
          }

          /// <summary>
          /// Runs one evaluation step on multi-input data. The first <c>FirstInputTensorCount</c>
          /// components of each element are the model inputs, and the rest are targets
          /// (and optionally sample weights). Also increments <c>_test_counter</c>.
          /// </summary>
          Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator)
          {
              var data = iterator.next();
              var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
              var inputs = new Tensors(data.Take(x_size).ToArray());

              Dictionary<string, float> outputs;
              // An (x, y) element with a single target has x_size + 1 components. The previous
              // `data.Length == 2` test only matched that layout when x_size == 1. With two or more
              // inputs, it sent plain (x, y) batches to the weighted overload with an empty weight list.
              if (data.Length <= x_size + 1)
              {
                  outputs = test_step(data_handler, inputs, new Tensors(data.Skip(x_size).ToArray()));
              }
              else
              {
                  // NOTE(review): assumes targets and sample weights each have x_size components,
                  // as the original implementation did — confirm against the dataset layout.
                  outputs = test_step(
                      data_handler,
                      inputs,
                      new Tensors(data.Skip(x_size).Take(x_size).ToArray()),
                      new Tensors(data.Skip(2 * x_size).ToArray()));
              }
              tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
              return outputs;
          }

          /// <summary>
          /// Evaluates one unweighted batch.
          /// Runs the forward pass in inference mode, computes the compiled loss and updates
          /// the compiled metrics.
          /// </summary>
          /// <returns>Each metric's name mapped to its current result.</returns>
          Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
          {
              (x, y) = data_handler.DataAdapter.Expand1d(x, y);
              var y_pred = Apply(x, training: false);
              // The returned loss tensor is not used directly. The call is kept presumably because it
              // updates loss-tracking state reported through `metrics` — confirm.
              compiled_loss.Call(y, y_pred);
              compiled_metrics.update_state(y, y_pred);
              return metrics.ToDictionary(m => m.Name, m => (float)m.result());
          }

          /// <summary>
          /// Evaluates one batch with per-sample weights applied to the loss.
          /// Runs the forward pass in inference mode, computes the weighted compiled loss and
          /// updates the compiled metrics.
          /// </summary>
          /// <returns>Each metric's name mapped to its current result.</returns>
          Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight)
          {
              (x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
              var y_pred = Apply(x, training: false);
              compiled_loss.Call(y, y_pred, sample_weight: sample_weight);
              // NOTE(review): sample_weight is applied to the loss only; the metrics are updated
              // unweighted. Keras passes sample_weight to compiled_metrics.update_state — check
              // whether the metrics container here supports that.
              compiled_metrics.update_state(y, y_pred);
              return metrics.ToDictionary(m => m.Name, m => (float)m.result());
          }
      }
  }