You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, may include dashes ('-'), and can be up to 35 characters long.

Model.Evaluate.cs 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. using Tensorflow.NumPy;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using Tensorflow.Keras.ArgsDefinition;
  6. using Tensorflow.Keras.Engine.DataAdapters;
  7. using static Tensorflow.Binding;
  8. namespace Tensorflow.Keras.Engine
  9. {
  public partial class Model
  {
      /// <summary>
      /// Runs the model in test mode over in-memory arrays, updating the compiled loss
      /// and metric state step by step. No values are returned by this overload; the
      /// per-metric results of the last step are computed and then discarded (use the
      /// <c>IDatasetV2</c> overload to receive them).
      /// </summary>
      /// <param name="x">Input data, forwarded to the data handler as <c>X</c>.</param>
      /// <param name="y">Target data, forwarded to the data handler as <c>Y</c>.</param>
      /// <param name="batch_size">Forwarded as <c>BatchSize</c>; -1 presumably lets the data handler pick a default — confirm in DataHandler.</param>
      /// <param name="verbose">Currently ignored by this implementation.</param>
      /// <param name="steps">Forwarded as <c>StepsPerEpoch</c>.</param>
      /// <param name="max_queue_size">Forwarded as <c>MaxQueueSize</c>.</param>
      /// <param name="workers">Forwarded as <c>Workers</c>.</param>
      /// <param name="use_multiprocessing">Forwarded as <c>UseMultiprocessing</c>.</param>
      /// <param name="return_dict">Currently ignored by this implementation.</param>
      public void evaluate(NDArray x, NDArray y,
          int batch_size = -1,
          int verbose = 1,
          int steps = -1,
          int max_queue_size = 10,
          int workers = 1,
          bool use_multiprocessing = false,
          bool return_dict = false)
      {
          var data_handler = new DataHandler(new DataHandlerArgs
          {
              X = x,
              Y = y,
              BatchSize = batch_size,
              StepsPerEpoch = steps,
              InitialEpoch = 0,
              Epochs = 1,
              MaxQueueSize = max_queue_size,
              Workers = workers,
              UseMultiprocessing = use_multiprocessing,
              Model = this,
              StepsPerExecution = _steps_per_execution
          });

          // The loop is run for its side effects on loss/metric state; the returned logs
          // are intentionally discarded because this overload returns void.
          run_evaluation_loop(data_handler);
      }

      /// <summary>
      /// Runs the model in test mode over <paramref name="x"/> and returns the value of
      /// each of the model's metrics as computed after the last evaluated step.
      /// </summary>
      /// <param name="x">
      /// Dataset to evaluate. Each element is indexed as [0] = inputs and [1] = targets
      /// by <c>test_function</c>.
      /// </param>
      /// <returns>
      /// One (metric name, value) pair per metric, or an empty array if the data handler
      /// produced no steps.
      /// </returns>
      public KeyValuePair<string, float>[] evaluate(IDatasetV2 x)
      {
          var data_handler = new DataHandler(new DataHandlerArgs
          {
              Dataset = x,
              Model = this,
              StepsPerExecution = _steps_per_execution
          });

          var logs = run_evaluation_loop(data_handler);

          // Previously this dereferenced a null `logs` when no step ran.
          if (logs is null)
          {
              return Array.Empty<KeyValuePair<string, float>>();
          }

          return logs
              .Select(log => new KeyValuePair<string, float>(log.Item1, (float)log.Item2))
              .ToArray();
      }

      /// <summary>
      /// Epoch/step loop shared by both <c>evaluate</c> overloads: resets metrics at the
      /// start of every epoch and runs one <c>test_function</c> call per step yielded by
      /// <paramref name="data_handler"/>.
      /// </summary>
      /// <param name="data_handler">Handler that supplies epochs, iterators and steps.</param>
      /// <returns>The logs from the last step executed, or <c>null</c> if no step ran.</returns>
      IEnumerable<(string, Tensor)> run_evaluation_loop(DataHandler data_handler)
      {
          IEnumerable<(string, Tensor)> logs = null;
          foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
          {
              reset_metrics();
              // callbacks.on_epoch_begin(epoch)
              // data_handler.catch_stop_iteration();
              foreach (var step in data_handler.steps())
              {
                  // callbacks.on_test_batch_begin(step)
                  logs = test_function(data_handler, iterator);
              }
          }
          return logs;
      }

      /// <summary>
      /// Pulls the next element from <paramref name="iterator"/>, runs one test step on it,
      /// and increments <c>_test_counter</c> by one.
      /// </summary>
      /// <param name="data_handler">Handler whose data adapter is used by <c>test_step</c>.</param>
      /// <param name="iterator">Iterator over the current epoch's elements.</param>
      /// <returns>The (metric name, result) pairs produced by <c>test_step</c>.</returns>
      IEnumerable<(string, Tensor)> test_function(DataHandler data_handler, OwnedIterator iterator)
      {
          var data = iterator.next();
          var outputs = test_step(data_handler, data[0], data[1]);
          // NOTE(review): the dependency list is empty, so this scope imposes no ordering
          // between the step's ops and the counter update — confirm this is intended.
          tf_with(ops.control_dependencies(Array.Empty<object>()), ctl => _test_counter.assign_add(1));
          return outputs;
      }

      /// <summary>
      /// Runs a single inference-mode forward pass and updates the compiled loss and metrics.
      /// </summary>
      /// <param name="data_handler">Handler whose data adapter's <c>Expand1d</c> is applied to the inputs.</param>
      /// <param name="x">Input batch.</param>
      /// <param name="y">Target batch.</param>
      /// <returns>Each metric's name paired with its current <c>result()</c>.</returns>
      List<(string, Tensor)> test_step(DataHandler data_handler, Tensor x, Tensor y)
      {
          (x, y) = data_handler.DataAdapter.Expand1d(x, y);
          var y_pred = Apply(x, training: false);
          // The loss value is not used here; the call is kept for its effect on the
          // compiled loss (presumably updates a loss tracker — confirm in compiled_loss).
          compiled_loss.Call(y, y_pred);
          compiled_metrics.update_state(y, y_pred);
          return metrics.Select(metric => (metric.Name, metric.result())).ToList();
      }
  }
  98. }