You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CallbackList.cs 2.1 kB

2 years ago
2 years ago
2 years ago
2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using Tensorflow.Keras.Engine;
  5. namespace Tensorflow.Keras.Callbacks;
  /// <summary>
  /// Dispatches training, evaluation and prediction lifecycle events to every
  /// registered <see cref="ICallback"/>, in list order.
  /// </summary>
  /// <remarks>
  /// Dispatch uses <see cref="List{T}.ForEach(Action{T})"/>. A callback that adds
  /// or removes entries in <see cref="callbacks"/> while an event is being
  /// dispatched will cause that call to throw <see cref="InvalidOperationException"/>.
  /// </remarks>
  public class CallbackList
  {
      /// <summary>
      /// The registered callbacks, invoked in list order for every event.
      /// </summary>
      /// <remarks>
      /// Made public so that newly defined (user) callbacks can be added to this list.
      /// The constructor seeds it with a <see cref="History"/> followed by a
      /// <see cref="ProgbarLogger"/>.
      /// </remarks>
      public List<ICallback> callbacks = new List<ICallback>();

      /// <summary>
      /// The first <see cref="History"/> callback in <see cref="callbacks"/>,
      /// or <c>null</c> if none is registered.
      /// </summary>
      /// <remarks>
      /// The callback is looked up by type rather than by position. Because
      /// <see cref="callbacks"/> is public, callers may insert, reorder or clear
      /// entries. A fixed <c>callbacks[0]</c> lookup would then silently return
      /// <c>null</c>, or throw <see cref="ArgumentOutOfRangeException"/> on an
      /// empty list.
      /// </remarks>
      public History History => callbacks.Find(callback => callback is History) as History;

      /// <summary>
      /// Creates the list and registers the default callbacks: a
      /// <see cref="History"/> first, then a <see cref="ProgbarLogger"/>.
      /// </summary>
      /// <param name="parameters">Passed unchanged to both default callbacks.</param>
      public CallbackList(CallbackParams parameters)
      {
          callbacks.Add(new History(parameters));
          callbacks.Add(new ProgbarLogger(parameters));
      }

      /// <summary>Forwards <c>on_train_begin</c> to every registered callback.</summary>
      public void on_train_begin()
      {
          callbacks.ForEach(x => x.on_train_begin());
      }

      /// <summary>Forwards <c>on_test_begin</c> to every registered callback.</summary>
      public void on_test_begin()
      {
          callbacks.ForEach(x => x.on_test_begin());
      }

      /// <summary>Forwards <c>on_epoch_begin</c> to every registered callback.</summary>
      /// <param name="epoch">Epoch index, passed through unchanged.</param>
      public void on_epoch_begin(int epoch)
      {
          callbacks.ForEach(x => x.on_epoch_begin(epoch));
      }

      /// <summary>Forwards <c>on_train_batch_begin</c> to every registered callback.</summary>
      /// <param name="step">Step counter, passed through unchanged.</param>
      public void on_train_batch_begin(long step)
      {
          callbacks.ForEach(x => x.on_train_batch_begin(step));
      }

      /// <summary>Forwards <c>on_train_batch_end</c> to every registered callback.</summary>
      /// <param name="end_step">Step counter, passed through unchanged.</param>
      /// <param name="logs">Named float values; the same dictionary instance is shared by all callbacks.</param>
      public void on_train_batch_end(long end_step, Dictionary<string, float> logs)
      {
          callbacks.ForEach(x => x.on_train_batch_end(end_step, logs));
      }

      /// <summary>Forwards <c>on_epoch_end</c> to every registered callback.</summary>
      /// <param name="epoch">Epoch index, passed through unchanged.</param>
      /// <param name="epoch_logs">Named float values; the same dictionary instance is shared by all callbacks.</param>
      public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
      {
          callbacks.ForEach(x => x.on_epoch_end(epoch, epoch_logs));
      }

      /// <summary>Forwards <c>on_predict_begin</c> to every registered callback.</summary>
      public void on_predict_begin()
      {
          callbacks.ForEach(x => x.on_predict_begin());
      }

      /// <summary>Forwards <c>on_predict_batch_begin</c> to every registered callback.</summary>
      /// <param name="step">Step counter, passed through unchanged.</param>
      public void on_predict_batch_begin(long step)
      {
          callbacks.ForEach(x => x.on_predict_batch_begin(step));
      }

      /// <summary>Forwards <c>on_predict_batch_end</c> to every registered callback.</summary>
      /// <param name="end_step">Step counter, passed through unchanged.</param>
      /// <param name="logs">Named <see cref="Tensors"/> values; the same dictionary instance is shared by all callbacks.</param>
      public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs)
      {
          callbacks.ForEach(x => x.on_predict_batch_end(end_step, logs));
      }

      /// <summary>Forwards <c>on_predict_end</c> to every registered callback.</summary>
      public void on_predict_end()
      {
          callbacks.ForEach(x => x.on_predict_end());
      }

      /// <summary>Forwards <c>on_test_batch_begin</c> to every registered callback.</summary>
      /// <param name="step">Step counter, passed through unchanged.</param>
      public void on_test_batch_begin(long step)
      {
          callbacks.ForEach(x => x.on_test_batch_begin(step));
      }

      /// <summary>Forwards <c>on_test_batch_end</c> to every registered callback.</summary>
      /// <param name="end_step">Step counter, passed through unchanged.</param>
      /// <param name="logs">Named float values; the same dictionary instance is shared by all callbacks.</param>
      public void on_test_batch_end(long end_step, Dictionary<string, float> logs)
      {
          callbacks.ForEach(x => x.on_test_batch_end(end_step, logs));
      }
  }