You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Model.Fit.cs 4.9 kB

2 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. using Tensorflow.NumPy;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using Tensorflow.Keras.ArgsDefinition;
  6. using Tensorflow.Keras.Engine.DataAdapters;
  7. using System.Diagnostics;
  8. using Tensorflow.Keras.Callbacks;
  9. using System.Data;
  namespace Tensorflow.Keras.Engine
  {
      public partial class Model
      {
          /// <summary>
          /// Trains the model for a fixed number of epochs (iterations on a dataset)
          /// using in-memory arrays.
          /// </summary>
          /// <param name="x">Input samples; dimension 0 indexes samples.</param>
          /// <param name="y">Targets; must have the same size as <paramref name="x"/> in dimension 0.</param>
          /// <param name="batch_size">
          /// Forwarded to the data handler as <c>BatchSize</c>.
          /// NOTE(review): -1 presumably lets the handler choose a default — confirm.
          /// </param>
          /// <param name="epochs">Number of epochs; forwarded to the data handler and to the callbacks.</param>
          /// <param name="verbose">Verbosity level forwarded to the callbacks.</param>
          /// <param name="validation_split">
          /// Fraction in [0, 1) of the trailing samples of <paramref name="x"/>/<paramref name="y"/>
          /// that is excluded from training. NOTE(review): the held-out samples are not
          /// currently evaluated as validation data.
          /// </param>
          /// <param name="shuffle">Forwarded to the data handler as <c>Shuffle</c>.</param>
          /// <param name="initial_epoch">Forwarded to the data handler as <c>InitialEpoch</c>.</param>
          /// <param name="max_queue_size">Forwarded to the data handler as <c>MaxQueueSize</c>.</param>
          /// <param name="workers">Forwarded to the data handler as <c>Workers</c>.</param>
          /// <param name="use_multiprocessing">Forwarded to the data handler as <c>UseMultiprocessing</c>.</param>
          /// <returns>The <c>History</c> collected by the callback list, typed as <see cref="ICallback"/>.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="x"/> or <paramref name="y"/> is null.</exception>
          /// <exception cref="InvalidArgumentError">x and y differ in dimension 0.</exception>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="validation_split"/> is outside [0, 1) or NaN.</exception>
          public ICallback fit(NDArray x, NDArray y,
              int batch_size = -1,
              int epochs = 1,
              int verbose = 1,
              float validation_split = 0f,
              bool shuffle = true,
              int initial_epoch = 0,
              int max_queue_size = 10,
              int workers = 1,
              bool use_multiprocessing = false)
          {
              if (x is null)
              {
                  throw new ArgumentNullException(nameof(x));
              }
              if (y is null)
              {
                  throw new ArgumentNullException(nameof(y));
              }
              if (x.dims[0] != y.dims[0])
              {
                  throw new InvalidArgumentError(
                      $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}");
              }
              // Negated range test so that NaN is rejected too (every comparison with NaN is false).
              if (!(validation_split >= 0f && validation_split < 1f))
              {
                  throw new ArgumentOutOfRangeException(nameof(validation_split), validation_split,
                      "validation_split must be in the range [0, 1).");
              }

              // Compute in double: multiplying the sample count by a float loses precision
              // above 2^24 samples and could silently drop training samples even when
              // validation_split is 0.
              int train_count = Convert.ToInt32(x.dims[0] * (1.0 - validation_split));
              var train_x = x[new Slice(0, train_count)];
              var train_y = y[new Slice(0, train_count)];
              // NOTE(review): the trailing samples x[train_count:] / y[train_count:] are held
              // out from training but not evaluated; FitInternal only accepts an IDatasetV2
              // for validation. TODO: wire them through once an array-based evaluation path exists.

              var data_handler = new DataHandler(new DataHandlerArgs
              {
                  X = train_x,
                  Y = train_y,
                  BatchSize = batch_size,
                  InitialEpoch = initial_epoch,
                  Epochs = epochs,
                  Shuffle = shuffle,
                  MaxQueueSize = max_queue_size,
                  Workers = workers,
                  UseMultiprocessing = use_multiprocessing,
                  Model = this,
                  StepsPerExecution = _steps_per_execution
              });
              return FitInternal(data_handler, epochs, verbose);
          }

          /// <summary>
          /// Trains the model for a fixed number of epochs on a dataset, optionally
          /// evaluating <paramref name="validation_data"/> at the end of every epoch.
          /// </summary>
          /// <param name="dataset">Training dataset.</param>
          /// <param name="validation_data">When non-null, evaluated after each epoch; its metrics are logged with a <c>val_</c> prefix.</param>
          /// <param name="batch_size">Forwarded to the data handler as <c>BatchSize</c>.</param>
          /// <param name="epochs">Number of epochs; forwarded to the data handler and to the callbacks.</param>
          /// <param name="verbose">Verbosity level forwarded to the callbacks.</param>
          /// <param name="validation_split">Not used by this overload; a dataset is not split.</param>
          /// <param name="shuffle">Forwarded to the data handler as <c>Shuffle</c>.</param>
          /// <param name="initial_epoch">Forwarded to the data handler as <c>InitialEpoch</c>.</param>
          /// <param name="max_queue_size">Forwarded to the data handler as <c>MaxQueueSize</c>.</param>
          /// <param name="workers">Forwarded to the data handler as <c>Workers</c>.</param>
          /// <param name="use_multiprocessing">Forwarded to the data handler as <c>UseMultiprocessing</c>.</param>
          /// <returns>The <c>History</c> collected by the callback list.</returns>
          public History fit(IDatasetV2 dataset,
              IDatasetV2 validation_data = null,
              int batch_size = -1,
              int epochs = 1,
              int verbose = 1,
              float validation_split = 0f,
              bool shuffle = true,
              int initial_epoch = 0,
              int max_queue_size = 10,
              int workers = 1,
              bool use_multiprocessing = false)
          {
              var data_handler = new DataHandler(new DataHandlerArgs
              {
                  Dataset = dataset,
                  BatchSize = batch_size,
                  InitialEpoch = initial_epoch,
                  Epochs = epochs,
                  Shuffle = shuffle,
                  MaxQueueSize = max_queue_size,
                  Workers = workers,
                  UseMultiprocessing = use_multiprocessing,
                  Model = this,
                  StepsPerExecution = _steps_per_execution
              });
              return FitInternal(data_handler, epochs, verbose, validation_data: validation_data);
          }

          /// <summary>
          /// Shared epoch/step training loop used by both <c>fit</c> overloads.
          /// Drives the callback list, runs one training step per handler step, optionally
          /// evaluates <paramref name="validation_data"/> after each epoch, and stops early
          /// once <c>stop_training</c> is set.
          /// </summary>
          /// <param name="data_handler">Supplies the epochs, steps, and iterators to train on.</param>
          /// <param name="epochs">Number of epochs, reported to the callbacks.</param>
          /// <param name="verbose">Verbosity level, reported to the callbacks.</param>
          /// <param name="validation_data">Optional dataset evaluated at the end of each epoch.</param>
          /// <returns>The <c>History</c> collected by the callback list.</returns>
          private History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data = null)
          {
              stop_training = false;
              _train_counter.assign(0);
              var callbacks = new CallbackList(new CallbackParams
              {
                  Model = this,
                  Verbose = verbose,
                  Epochs = epochs,
                  Steps = data_handler.Inferredsteps
              });
              callbacks.on_train_begin();

              foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
              {
                  reset_metrics();
                  callbacks.on_epoch_begin(epoch);

                  var logs = new Dictionary<string, float>();
                  foreach (var step in data_handler.steps())
                  {
                      callbacks.on_train_batch_begin(step);
                      logs = train_step_function(data_handler, iterator);
                      var end_step = step + data_handler.StepIncrement;
                      callbacks.on_train_batch_end(end_step, logs);

                      // stop_training is reset above and presumably set by a callback through
                      // the model (early stopping); honour it mid-epoch as Keras does.
                      if (stop_training)
                      {
                          break;
                      }
                  }

                  if (validation_data != null)
                  {
                      var val_logs = evaluate(validation_data);
                      foreach (var log in val_logs)
                      {
                          logs["val_" + log.Key] = log.Value;
                      }
                  }

                  callbacks.on_epoch_end(epoch, logs);

                  // Force a full collection each epoch. NOTE(review): presumably to release
                  // native tensor memory held by finalizable wrappers promptly — confirm it is still needed.
                  GC.Collect();
                  GC.WaitForPendingFinalizers();

                  // Without this check a stop request from a callback would be ignored
                  // and training would run for all remaining epochs.
                  if (stop_training)
                  {
                      break;
                  }
              }

              return callbacks.History;
          }
      }
  }