You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Model.Fit.cs 4.0 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using Tensorflow.Keras.ArgsDefinition;
  6. using Tensorflow.Keras.Engine.DataAdapters;
  namespace Tensorflow.Keras.Engine
  {
      public partial class Model
      {
          /// <summary>
          /// Trains the model for a fixed number of epochs (iterations on a dataset)
          /// using in-memory input and target arrays.
          /// </summary>
          /// <param name="x">Input samples. The first dimension is treated as the sample axis.</param>
          /// <param name="y">
          /// Targets. Must have the same length along the first dimension as <paramref name="x"/>.
          /// </param>
          /// <param name="batch_size">
          /// Forwarded to the data handler as <c>BatchSize</c>; the default <c>-1</c> is passed through unchanged.
          /// </param>
          /// <param name="epochs">Number of epochs to train for.</param>
          /// <param name="verbose"><c>1</c> prints one progress line per training step; any other value prints nothing.</param>
          /// <param name="validation_split">
          /// Fraction in [0, 1) of the samples, taken from the END of <paramref name="x"/> and <paramref name="y"/>,
          /// that is excluded from training. NOTE: the held-out samples are not evaluated yet.
          /// </param>
          /// <param name="shuffle">Forwarded to the data handler as <c>Shuffle</c>.</param>
          /// <param name="initial_epoch">Forwarded to the data handler as <c>InitialEpoch</c>.</param>
          /// <param name="max_queue_size">Forwarded to the data handler as <c>MaxQueueSize</c>.</param>
          /// <param name="workers">Forwarded to the data handler as <c>Workers</c>.</param>
          /// <param name="use_multiprocessing">Forwarded to the data handler as <c>UseMultiprocessing</c>.</param>
          /// <exception cref="ArgumentNullException"><paramref name="x"/> or <paramref name="y"/> is null.</exception>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="validation_split"/> is not in [0, 1).</exception>
          /// <exception cref="ArgumentException"><paramref name="x"/> and <paramref name="y"/> differ in sample count.</exception>
          public void fit(NDArray x, NDArray y,
              int batch_size = -1,
              int epochs = 1,
              int verbose = 1,
              float validation_split = 0f,
              bool shuffle = true,
              int initial_epoch = 0,
              int max_queue_size = 10,
              int workers = 1,
              bool use_multiprocessing = false)
          {
              // `is null` is used deliberately: NumSharp overloads `==` on NDArray.
              if (x is null)
              {
                  throw new ArgumentNullException(nameof(x));
              }
              if (y is null)
              {
                  throw new ArgumentNullException(nameof(y));
              }
              // Written as a negated range test so that NaN is rejected too.
              if (!(validation_split >= 0f && validation_split < 1f))
              {
                  throw new ArgumentOutOfRangeException(nameof(validation_split), validation_split,
                      "validation_split must be in the range [0, 1).");
              }

              var num_samples = x.shape[0];
              if (y.shape[0] != num_samples)
              {
                  throw new ArgumentException(
                      $"x and y must contain the same number of samples, but got {num_samples} and {y.shape[0]}.",
                      nameof(y));
              }

              // Convert.ToInt32 rounds to the nearest integer (ties to even) rather than truncating.
              int train_count = Convert.ToInt32(num_samples * (1 - validation_split));
              var train_x = x[new Slice(0, train_count)];
              var train_y = y[new Slice(0, train_count)];
              // TODO: the trailing validation slice (samples from train_count onward) is not evaluated yet,
              // so it is not materialized here.

              data_handler = new DataHandler(new DataHandlerArgs
              {
                  X = train_x,
                  Y = train_y,
                  BatchSize = batch_size,
                  InitialEpoch = initial_epoch,
                  Epochs = epochs,
                  Shuffle = shuffle,
                  MaxQueueSize = max_queue_size,
                  Workers = workers,
                  UseMultiprocessing = use_multiprocessing,
                  Model = this,
                  StepsPerExecution = _steps_per_execution
              });
              FitInternal(epochs, verbose);
          }

          /// <summary>
          /// Trains the model for a fixed number of epochs (iterations on a dataset)
          /// using batches drawn from <paramref name="dataset"/>.
          /// </summary>
          /// <param name="dataset">Training data.</param>
          /// <param name="validation_data">Not used by this overload yet.</param>
          /// <param name="batch_size">
          /// Forwarded to the data handler as <c>BatchSize</c>; the default <c>-1</c> is passed through unchanged.
          /// </param>
          /// <param name="epochs">Number of epochs to train for.</param>
          /// <param name="verbose"><c>1</c> prints one progress line per training step; any other value prints nothing.</param>
          /// <param name="validation_split">Not used by this overload yet.</param>
          /// <param name="shuffle">Forwarded to the data handler as <c>Shuffle</c>.</param>
          /// <param name="initial_epoch">Forwarded to the data handler as <c>InitialEpoch</c>.</param>
          /// <param name="max_queue_size">Forwarded to the data handler as <c>MaxQueueSize</c>.</param>
          /// <param name="workers">Forwarded to the data handler as <c>Workers</c>.</param>
          /// <param name="use_multiprocessing">Forwarded to the data handler as <c>UseMultiprocessing</c>.</param>
          /// <exception cref="ArgumentNullException"><paramref name="dataset"/> is null.</exception>
          public void fit(IDatasetV2 dataset,
              IDatasetV2 validation_data = null,
              int batch_size = -1,
              int epochs = 1,
              int verbose = 1,
              float validation_split = 0f,
              bool shuffle = true,
              int initial_epoch = 0,
              int max_queue_size = 10,
              int workers = 1,
              bool use_multiprocessing = false)
          {
              if (dataset is null)
              {
                  throw new ArgumentNullException(nameof(dataset));
              }

              data_handler = new DataHandler(new DataHandlerArgs
              {
                  Dataset = dataset,
                  BatchSize = batch_size,
                  InitialEpoch = initial_epoch,
                  Epochs = epochs,
                  Shuffle = shuffle,
                  MaxQueueSize = max_queue_size,
                  Workers = workers,
                  UseMultiprocessing = use_multiprocessing,
                  Model = this,
                  StepsPerExecution = _steps_per_execution
              });
              FitInternal(epochs, verbose);
          }

          /// <summary>
          /// Runs the training loop over the epochs and steps produced by <c>data_handler</c>,
          /// which must already have been configured by one of the <c>fit</c> overloads.
          /// Stops after the current epoch if <c>stop_training</c> becomes true during training.
          /// </summary>
          /// <param name="epochs">Total epoch count; used only in the progress output.</param>
          /// <param name="verbose"><c>1</c> prints one progress line per step; any other value prints nothing.</param>
          void FitInternal(int epochs, int verbose)
          {
              stop_training = false;
              _train_counter.assign(0);
              foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
              {
                  // TODO: reset_metrics(), callbacks.on_epoch_begin(epoch) and
                  // data_handler.catch_stop_iteration() are not wired up yet.
                  foreach (var step in data_handler.steps())
                  {
                      // TODO: callbacks.on_train_batch_begin(step) is not wired up yet.
                      var results = train_step_function(iterator);
                      if (verbose == 1)
                      {
                          var result_pairs = string.Join(", ", results.Select(r => $"{r.Item1}: {(float)r.Item2:F6}"));
                          Console.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}");
                      }
                  }

                  // NOTE(review): forcing a full collection after every epoch is presumably meant to release
                  // native tensor memory held by objects awaiting finalization — confirm before removing.
                  GC.Collect();
                  GC.WaitForPendingFinalizers();

                  // Honor a stop request raised during this epoch (Keras checks this at epoch end).
                  if (stop_training)
                  {
                      break;
                  }
              }
          }
      }
  }