You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

Earlystopping.cs 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. using Tensorflow.Keras.Engine;
  2. namespace Tensorflow.Keras.Callbacks;
  3. /// <summary>
  4. /// Stop training when a monitored metric has stopped improving.
  5. /// </summary>
  6. /// <param name="parameters"></param>
  7. /// <param name="monitor"></param>
  8. public class EarlyStopping: ICallback
  9. {
  10. int _paitence;
  11. int _min_delta;
  12. int _verbose;
  13. int _stopped_epoch;
  14. int _wait;
  15. int _best_epoch;
  16. int _start_from_epoch;
  17. float _best;
  18. float _baseline;
  19. string _monitor;
  20. string _mode;
  21. bool _restore_best_weights;
  22. List<IVariableV1>? _best_weights;
  23. CallbackParams _parameters;
  24. public Dictionary<string, List<float>>? history { get; set; }
  25. // user need to pass a CallbackParams to EarlyStopping, CallbackParams at least need the model
  26. public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", int min_delta = 0, int patience = 0,
  27. int verbose = 1, string mode = "auto", float baseline = 0f, bool restore_best_weights = false,
  28. int start_from_epoch = 0)
  29. {
  30. _parameters = parameters;
  31. _stopped_epoch = 0;
  32. _wait = 0;
  33. _monitor = monitor;
  34. _paitence = patience;
  35. _verbose = verbose;
  36. _baseline = baseline;
  37. _start_from_epoch = start_from_epoch;
  38. _min_delta = Math.Abs(min_delta);
  39. _restore_best_weights = restore_best_weights;
  40. _mode = mode;
  41. if (mode != "auto" && mode != "min" && mode != "max")
  42. {
  43. Console.WriteLine("EarlyStopping mode %s is unknown, fallback to auto mode.", mode);
  44. }
  45. }
  46. public void on_train_begin()
  47. {
  48. _wait = 0;
  49. _stopped_epoch = 0;
  50. _best_epoch = 0;
  51. _best = (float)np.Inf;
  52. }
  53. public void on_epoch_begin(int epoch)
  54. {
  55. }
  56. public void on_train_batch_begin(long step)
  57. {
  58. }
  59. public void on_train_batch_end(long end_step, Dictionary<string, float> logs)
  60. {
  61. }
  62. public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
  63. {
  64. var current = get_monitor_value(epoch_logs);
  65. // If no monitor value exists or still in initial warm-up stage.
  66. if (current == 0f || epoch < _start_from_epoch)
  67. return;
  68. // Restore the weights after first epoch if no progress is ever made.
  69. if (_restore_best_weights && _best_weights == null)
  70. {
  71. _best_weights = _parameters.Model.TrainableWeights;
  72. }
  73. _wait += 1;
  74. if (_is_improvement(current, _best))
  75. {
  76. _best = current;
  77. _best_epoch = epoch;
  78. if (_restore_best_weights)
  79. _best_weights = _parameters.Model.TrainableWeights;
  80. // Only restart wait if we beat both the baseline and our previous best.
  81. if (_baseline == 0f || _is_improvement(current, _baseline))
  82. _wait = 0;
  83. }
  84. // Only check after the first epoch.
  85. if (_wait >= _paitence && epoch > 0)
  86. {
  87. _stopped_epoch = epoch;
  88. _parameters.Model.Stop_training = true;
  89. if (_restore_best_weights && _best_weights != null)
  90. {
  91. if (_verbose > 0)
  92. {
  93. Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}");
  94. }
  95. }
  96. // Because loading the weight variable into the model has not yet been implemented, so Earlystopping can't load best_weight yet.
  97. // TODO(Wanglongzhi2001): implement it.
  98. // _parameters.Model.load_weights(best_weights);
  99. }
  100. }
  101. public void on_train_end()
  102. {
  103. if (_stopped_epoch > 0 && _verbose > 0)
  104. {
  105. Console.WriteLine($"Epoch {_stopped_epoch + 1}: early stopping");
  106. }
  107. }
  108. public void on_predict_begin() { }
  109. public void on_predict_batch_begin(long step) { }
  110. public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) { }
  111. public void on_predict_end() { }
  112. public void on_test_begin() { }
  113. public void on_test_batch_begin(long step) { }
  114. public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) { }
  115. float get_monitor_value(Dictionary<string, float> logs)
  116. {
  117. logs = logs ?? new Dictionary<string, float>();
  118. float monitor_value = logs[_monitor];
  119. if (monitor_value == 0f)
  120. {
  121. Console.WriteLine($"Early stopping conditioned on metric {_monitor} " +
  122. $"which is not available. Available metrics are: {string.Join(", ", logs.Keys)}");
  123. }
  124. return monitor_value;
  125. }
  126. public bool _is_improvement(float monitor_value, float reference_value)
  127. {
  128. bool less_op = (monitor_value - _min_delta) < reference_value;
  129. bool greater_op = (monitor_value - _min_delta) >= reference_value;
  130. if (_mode == "min")
  131. return less_op;
  132. else if (_mode == "max")
  133. return greater_op;
  134. else
  135. {
  136. if (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc"))
  137. {
  138. return greater_op;
  139. }
  140. else
  141. return less_op;
  142. }
  143. }
  144. }