You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ProgbarLogger.cs 3.3 kB

2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. using System.Diagnostics;
  2. using Tensorflow.Keras.Engine;
  3. namespace Tensorflow.Keras.Callbacks
  4. {
  5. public class ProgbarLogger : ICallback
  6. {
  7. bool _called_in_fit = false;
  8. int seen = 0;
  9. CallbackParams _parameters;
  10. Stopwatch _sw;
  11. public Dictionary<string, List<float>> history { get; set; }
  12. public ProgbarLogger(CallbackParams parameters)
  13. {
  14. _parameters = parameters;
  15. }
  16. public void on_train_begin()
  17. {
  18. _called_in_fit = true;
  19. _sw = new Stopwatch();
  20. }
  21. public void on_train_end() { }
  22. public void on_test_begin()
  23. {
  24. _sw = new Stopwatch();
  25. }
  26. public void on_epoch_begin(int epoch)
  27. {
  28. _reset_progbar();
  29. _maybe_init_progbar();
  30. Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{_parameters.Epochs:D3}");
  31. }
  32. public void on_train_batch_begin(long step)
  33. {
  34. _sw.Restart();
  35. }
  36. public void on_train_batch_end(long end_step, Dictionary<string, float> logs)
  37. {
  38. _sw.Stop();
  39. var elapse = _sw.ElapsedMilliseconds;
  40. var results = string.Join(" - ", logs.Select(x => $"{x.Key}: {(float)x.Value:F6}"));
  41. var progress = "";
  42. var length = 30.0 / _parameters.Steps;
  43. for (int i = 0; i < Math.Floor(end_step * length - 1); i++)
  44. progress += "=";
  45. if (progress.Length < 28)
  46. progress += ">";
  47. else
  48. progress += "=";
  49. var remaining = "";
  50. for (int i = 1; i < 30 - progress.Length; i++)
  51. remaining += ".";
  52. Binding.tf_output_redirect.Write($"{end_step + 1:D4}/{_parameters.Steps:D4} [{progress}{remaining}] - {elapse}ms/step - {results}");
  53. if (!Console.IsOutputRedirected)
  54. {
  55. Console.CursorLeft = 0;
  56. }
  57. }
  58. public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
  59. {
  60. Console.WriteLine();
  61. }
  62. void _reset_progbar()
  63. {
  64. seen = 0;
  65. }
  66. void _maybe_init_progbar()
  67. {
  68. }
  69. public void on_predict_begin()
  70. {
  71. _reset_progbar();
  72. _maybe_init_progbar();
  73. }
  74. public void on_predict_batch_begin(long step)
  75. {
  76. }
  77. public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs)
  78. {
  79. }
  80. public void on_predict_end()
  81. {
  82. }
  83. public void on_test_batch_begin(long step)
  84. {
  85. _sw.Restart();
  86. }
  87. public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
  88. {
  89. _sw.Stop();
  90. var elapse = _sw.ElapsedMilliseconds;
  91. var results = string.Join(" - ", logs.Select(x => $"{x.Item1}: {(float)x.Item2.numpy():F6}"));
  92. Binding.tf_output_redirect.Write($"{end_step + 1:D4}/{_parameters.Steps:D4} - {elapse}ms/step - {results}");
  93. if (!Console.IsOutputRedirected)
  94. {
  95. Console.CursorLeft = 0;
  96. }
  97. }
  98. }
  99. }