You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Model.Evaluate.cs 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. using Tensorflow.NumPy;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using Tensorflow.Keras.ArgsDefinition;
  6. using Tensorflow.Keras.Engine.DataAdapters;
  7. using static Tensorflow.Binding;
  namespace Tensorflow.Keras.Engine
  {
      public partial class Model
      {
          /// <summary>
          /// Evaluates the model in test mode on in-memory arrays and writes the loss and
          /// metric values of each epoch to <c>Binding.tf_output_redirect</c>.
          /// </summary>
          /// <param name="x">Input data.</param>
          /// <param name="y">Target data.</param>
          /// <param name="batch_size">Forwarded to the data handler as <c>BatchSize</c>.</param>
          /// <param name="verbose">0 suppresses all output; any other value prints progress and metrics.</param>
          /// <param name="steps">Forwarded to the data handler as <c>StepsPerEpoch</c>.</param>
          /// <param name="max_queue_size">Forwarded to the data handler as <c>MaxQueueSize</c>.</param>
          /// <param name="workers">Forwarded to the data handler as <c>Workers</c>.</param>
          /// <param name="use_multiprocessing">Forwarded to the data handler as <c>UseMultiprocessing</c>.</param>
          /// <param name="return_dict">Currently ignored: this method does not return the metric values.</param>
          public void evaluate(NDArray x, NDArray y,
              int batch_size = -1,
              int verbose = 1,
              int steps = -1,
              int max_queue_size = 10,
              int workers = 1,
              bool use_multiprocessing = false,
              bool return_dict = false)
          {
              data_handler = new DataHandler(new DataHandlerArgs
              {
                  X = x,
                  Y = y,
                  BatchSize = batch_size,
                  StepsPerEpoch = steps,
                  InitialEpoch = 0,
                  Epochs = 1,
                  MaxQueueSize = max_queue_size,
                  Workers = workers,
                  UseMultiprocessing = use_multiprocessing,
                  Model = this,
                  StepsPerExecution = _steps_per_execution
              });
              run_test_epochs(verbose);
          }

          /// <summary>
          /// Evaluates the model in test mode on a dataset and writes the loss and metric
          /// values of each epoch to <c>Binding.tf_output_redirect</c>.
          /// </summary>
          /// <param name="x">Dataset providing the evaluation batches (presumably input/target pairs — confirm).</param>
          public void evaluate(IDatasetV2 x)
          {
              data_handler = new DataHandler(new DataHandlerArgs
              {
                  Dataset = x,
                  Model = this,
                  StepsPerExecution = _steps_per_execution
              });
              // This overload has no verbosity switch, so it always reports progress.
              run_test_epochs(verbose: 1);
          }

          /// <summary>
          /// Shared test loop for both <c>evaluate</c> overloads: for every epoch produced by
          /// <c>data_handler</c>, resets the metrics, runs <c>test_function</c> once per step and
          /// prints the metric values returned by the last step.
          /// </summary>
          /// <param name="verbose">0 suppresses all output; any other value prints.</param>
          void run_test_epochs(int verbose)
          {
              if (verbose != 0)
                  Binding.tf_output_redirect.WriteLine($"Testing...");

              foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
              {
                  reset_metrics();
                  // TODO: callbacks.on_epoch_begin(epoch) / data_handler.catch_stop_iteration()

                  IEnumerable<(string, Tensor)> results = null;
                  foreach (var _ in data_handler.steps())
                  {
                      // TODO: callbacks.on_test_batch_begin(step)
                      results = test_function(iterator);
                  }

                  // No step ran (e.g. empty input), so there are no metric values to report;
                  // previously this dereferenced null and threw.
                  if (results is null || verbose == 0)
                      continue;

                  var metricsText = string.Join(", ", results.Select(r => $"{r.Item1}: {(float)r.Item2}"));
                  Binding.tf_output_redirect.WriteLine($"iterator: {epoch + 1}, " + metricsText);
              }
          }

          /// <summary>
          /// Pulls the next batch from <paramref name="iterator"/>, runs one test step on it and
          /// increments <c>_test_counter</c>.
          /// </summary>
          /// <param name="iterator">Batch iterator; each element is indexed as [0] = inputs, [1] = targets.</param>
          /// <returns>The (metric name, value) pairs produced by <c>test_step</c>.</returns>
          IEnumerable<(string, Tensor)> test_function(OwnedIterator iterator)
          {
              var data = iterator.next();
              var outputs = test_step(data[0], data[1]);
              tf_with(ops.control_dependencies(Array.Empty<object>()), ctl => _test_counter.assign_add(1));
              return outputs;
          }

          /// <summary>
          /// Runs a forward pass in inference mode, updates the compiled loss and metrics, and
          /// returns the current result of every metric in <c>metrics</c>.
          /// </summary>
          /// <param name="x">Batch inputs.</param>
          /// <param name="y">Batch targets.</param>
          /// <returns>A list of (metric name, current result) pairs.</returns>
          List<(string, Tensor)> test_step(Tensor x, Tensor y)
          {
              (x, y) = data_handler.DataAdapter.Expand1d(x, y);
              var y_pred = Apply(x, training: false);
              // The returned loss tensor is not used here; the call is kept for its side effects
              // (presumably the loss container tracks it among `metrics` — confirm).
              compiled_loss.Call(y, y_pred);
              compiled_metrics.update_state(y, y_pred);
              return metrics.Select(m => (m.Name, m.result())).ToList();
          }
      }
  }