You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Model.Evaluate.cs 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. using Tensorflow.NumPy;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using Tensorflow.Keras.ArgsDefinition;
  6. using Tensorflow.Keras.Engine.DataAdapters;
  7. using static Tensorflow.Binding;
  8. using Tensorflow.Keras.Layers;
  9. using Tensorflow.Keras.Utils;
  10. using Tensorflow;
  11. using Tensorflow.Keras.Callbacks;
  namespace Tensorflow.Keras.Engine
  {
      public partial class Model
      {
          /// <summary>
          /// Returns the loss value and metric values for the model in test mode,
          /// evaluating <paramref name="x"/> against <paramref name="y"/>.
          /// </summary>
          /// <param name="x">Input samples; dimension 0 must equal that of <paramref name="y"/>.</param>
          /// <param name="y">Target samples; dimension 0 must equal that of <paramref name="x"/>.</param>
          /// <param name="batch_size">Forwarded to the data handler as <c>BatchSize</c>.</param>
          /// <param name="verbose">Verbosity level passed to the callbacks.</param>
          /// <param name="steps">Forwarded to the data handler as <c>StepsPerEpoch</c>.</param>
          /// <param name="max_queue_size">Forwarded to the data handler as <c>MaxQueueSize</c>.</param>
          /// <param name="workers">Forwarded to the data handler as <c>Workers</c>.</param>
          /// <param name="use_multiprocessing">Forwarded to the data handler as <c>UseMultiprocessing</c>.</param>
          /// <param name="return_dict">Currently ignored; a dictionary is always returned.</param>
          /// <param name="is_val">When true, per-batch <c>on_test_batch_end</c> callbacks are suppressed.</param>
          /// <returns>
          /// A copy of the metric dictionary produced by the final test step,
          /// or an empty dictionary if no step ran.
          /// </returns>
          /// <exception cref="InvalidArgumentError">
          /// <paramref name="x"/> and <paramref name="y"/> differ in dimension 0.
          /// </exception>
          public Dictionary<string, float> evaluate(NDArray x, NDArray y,
              int batch_size = -1,
              int verbose = 1,
              int steps = -1,
              int max_queue_size = 10,
              int workers = 1,
              bool use_multiprocessing = false,
              bool return_dict = false,
              bool is_val = false
              )
          {
              if (x.dims[0] != y.dims[0])
              {
                  throw new InvalidArgumentError(
                      $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}");
              }

              var data_handler = new DataHandler(new DataHandlerArgs
              {
                  X = x,
                  Y = y,
                  BatchSize = batch_size,
                  StepsPerEpoch = steps,
                  InitialEpoch = 0,
                  Epochs = 1,
                  MaxQueueSize = max_queue_size,
                  Workers = workers,
                  UseMultiprocessing = use_multiprocessing,
                  Model = this,
                  StepsPerExecution = _steps_per_execution
              });

              // This overload has never fired on_epoch_begin, unlike the other two.
              // Preserved as-is. NOTE(review): confirm whether the difference is intentional.
              return run_test_loop(data_handler, verbose, is_val, notify_epoch_begin: false, test_function);
          }

          /// <summary>
          /// Returns the loss value and metric values for a multi-input model in test mode.
          /// </summary>
          /// <param name="x">The model's input tensors.</param>
          /// <param name="y">Target samples.</param>
          /// <param name="verbose">Verbosity level passed to the callbacks.</param>
          /// <param name="is_val">When true, per-batch <c>on_test_batch_end</c> callbacks are suppressed.</param>
          /// <returns>
          /// A copy of the metric dictionary produced by the final test step,
          /// or an empty dictionary if no step ran.
          /// </returns>
          public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, NDArray y, int verbose = 1, bool is_val = false)
          {
              var data_handler = new DataHandler(new DataHandlerArgs
              {
                  X = new Tensors(x),
                  Y = y,
                  Model = this,
                  StepsPerExecution = _steps_per_execution
              });

              return run_test_loop(data_handler, verbose, is_val, notify_epoch_begin: true, test_step_multi_inputs_function);
          }

          /// <summary>
          /// Returns the loss value and metric values for the model in test mode, reading from a dataset.
          /// </summary>
          /// <param name="x">Dataset whose elements hold the input at index 0 and the target at index 1.</param>
          /// <param name="verbose">Verbosity level passed to the callbacks.</param>
          /// <param name="is_val">When true, per-batch <c>on_test_batch_end</c> callbacks are suppressed.</param>
          /// <returns>
          /// A copy of the metric dictionary produced by the final test step,
          /// or an empty dictionary if no step ran.
          /// </returns>
          public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is_val = false)
          {
              var data_handler = new DataHandler(new DataHandlerArgs
              {
                  Dataset = x,
                  Model = this,
                  StepsPerExecution = _steps_per_execution
              });

              return run_test_loop(data_handler, verbose, is_val, notify_epoch_begin: true, test_function);
          }

          /// <summary>
          /// Evaluation loop shared by every <c>evaluate</c> overload. Builds the callback list,
          /// then runs <paramref name="step_fn"/> once per step of every epoch yielded by
          /// <paramref name="data_handler"/>, resetting metrics at the start of each epoch.
          /// </summary>
          /// <param name="data_handler">Source of epochs, steps and iterators.</param>
          /// <param name="verbose">Verbosity level passed to the callbacks.</param>
          /// <param name="is_val">When true, <c>on_test_batch_end</c> is not fired.</param>
          /// <param name="notify_epoch_begin">Whether to fire <c>on_epoch_begin</c> for each epoch.</param>
          /// <param name="step_fn">Runs one test step and returns its metric dictionary.</param>
          /// <returns>A copy of the last step's logs, or an empty dictionary when no step ran.</returns>
          Dictionary<string, float> run_test_loop(
              DataHandler data_handler,
              int verbose,
              bool is_val,
              bool notify_epoch_begin,
              Func<DataHandler, OwnedIterator, Dictionary<string, float>> step_fn)
          {
              var callbacks = new CallbackList(new CallbackParams
              {
                  Model = this,
                  Verbose = verbose,
                  Steps = data_handler.Inferredsteps
              });
              callbacks.on_test_begin();

              // Start empty rather than null. If the data source yields no step,
              // the caller gets an empty result instead of a NullReferenceException.
              var logs = new Dictionary<string, float>();
              foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
              {
                  reset_metrics();
                  if (notify_epoch_begin)
                  {
                      callbacks.on_epoch_begin(epoch);
                  }

                  foreach (var step in data_handler.steps())
                  {
                      callbacks.on_test_batch_begin(step);
                      logs = step_fn(data_handler, iterator);
                      var end_step = step + data_handler.StepIncrement;
                      if (!is_val)
                      {
                          callbacks.on_test_batch_end(end_step, logs);
                      }
                  }
              }

              // Hand back a copy so callers do not share the dictionary built by the step function.
              return new Dictionary<string, float>(logs);
          }

          /// <summary>
          /// Runs one test step on the next element of <paramref name="iterator"/>, treating
          /// element 0 as the input and element 1 as the target, then increments <c>_test_counter</c>.
          /// </summary>
          Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
          {
              var data = iterator.next();
              var outputs = test_step(data_handler, data[0], data[1]);
              tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
              return outputs;
          }

          /// <summary>
          /// Multi-input counterpart of <c>test_function</c>. The first
          /// <c>FirstInputTensorCount</c> tensors of the next element are the inputs and the
          /// remaining tensors are the targets. Increments <c>_test_counter</c>.
          /// </summary>
          Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator)
          {
              var data = iterator.next();
              var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
              // Evaluation must not run a training step. Use the inference-mode test step
              // and bump the test counter, not the training counter.
              var outputs = test_step_multi_inputs(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
              tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
              return outputs;
          }

          /// <summary>
          /// One inference-mode step. Predicts on <paramref name="x"/> with <c>training: false</c>,
          /// feeds the predictions and <paramref name="y"/> to the compiled loss and metrics, and
          /// returns each metric's current result keyed by metric name.
          /// </summary>
          Dictionary<string, float> test_step(DataHandler data_handler, Tensor x, Tensor y)
          {
              (x, y) = data_handler.DataAdapter.Expand1d(x, y);
              var y_pred = Apply(x, training: false);
              // The return value is unused. The call is kept, presumably because it updates
              // loss tracking — confirm before removing.
              compiled_loss.Call(y, y_pred);
              compiled_metrics.update_state(y, y_pred);
              return collect_metric_results();
          }

          /// <summary>
          /// Multi-input, inference-mode step. Same contract as <c>test_step</c>, but takes
          /// <see cref="Tensors"/> for the inputs and targets.
          /// </summary>
          Dictionary<string, float> test_step_multi_inputs(DataHandler data_handler, Tensors x, Tensors y)
          {
              // NOTE(review): if DataAdapter.Expand1d only has a (Tensor, Tensor) overload, the
              // implicit conversions here would reduce x and y to their first tensor.
              // Confirm a Tensors overload exists.
              (x, y) = data_handler.DataAdapter.Expand1d(x, y);
              var y_pred = Apply(x, training: false);
              compiled_loss.Call(y, y_pred);
              compiled_metrics.update_state(y, y_pred);
              return collect_metric_results();
          }

          /// <summary>
          /// Snapshots every tracked metric as name → current result converted to float.
          /// Throws (via ToDictionary) if two metrics share a name.
          /// </summary>
          Dictionary<string, float> collect_metric_results()
              => metrics.ToDictionary(metric => metric.Name, metric => (float)metric.result());
      }
  }